View Javadoc
1   /*
2    * The baseCode project
3    * 
4    * Copyright (c) 2008-2019 University of British Columbia
5    * 
6    * Licensed under the Apache License, Version 2.0 (the "License");
7    * you may not use this file except in compliance with the License.
8    * You may obtain a copy of the License at
9    *
10   *       http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   *
18   */
19  
20  package ubic.basecode.dataStructure.matrix;
21  
22  import java.util.ArrayList;
23  import java.util.Collection;
24  import java.util.List;
25  import java.util.function.DoubleFunction;
26  
27  import cern.colt.list.BooleanArrayList;
28  import ubic.basecode.math.Constants;
29  import cern.colt.list.DoubleArrayList;
30  import cern.colt.matrix.DoubleMatrix1D;
31  import cern.colt.matrix.DoubleMatrix2D;
32  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
33  
34  /**
35   * @author Paul
36   */
37  public class MatrixUtil {
38  
39      /**
40       * @param  d
41       * @return   true if any of the values are very close to zero.
42       */
43      public static boolean containsNearlyZeros( DoubleMatrix1D d ) {
44          for ( int i = 0; i < d.size(); i++ ) {
45              if ( Math.abs( d.getQuick( i ) ) < Constants.SMALL ) return true;
46          }
47          return false;
48      }
49  
50      /**
51       * Extract the diagonal from a matrix.
52       * 
53       * @param  matrix
54       * @return
55       */
56      public static DoubleMatrix1D diagonal( DoubleMatrix2D matrix ) {
57  
58          int mindim = Math.min( matrix.rows(), matrix.columns() );
59          DoubleMatrix1D result = new DenseDoubleMatrix1D( mindim );
60          for ( int i = mindim; --i >= 0; ) {
61              result.set( i, matrix.getQuick( i, i ) );
62          }
63          return result;
64      }
65  
66      /**
67       * @param  n
68       * @param  indexToDrop
69       * @return
70       */
71      public static DoubleMatrix2D dropColumn( DoubleMatrix2D n, int indexToDrop ) {
72          int columns = n.columns() - 1;
73          if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
74          DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), columns );
75          int k = 0;
76          for ( int j = 0; j < n.columns(); j++ ) {
77              if ( indexToDrop == j ) {
78                  continue;
79              }
80              for ( int i = 0; i < n.rows(); i++ ) {
81                  res.set( i, k, n.getQuick( i, j ) );
82              }
83              k++;
84          }
85          return res;
86      }
87  
88      /**
89       * @param n
90       * @param droppedColumns
91       */
92      public static DoubleMatrix2D dropColumns( DoubleMatrix2D n, Collection<Integer> droppedColumns ) {
93          int columns = n.columns() - droppedColumns.size();
94          if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
95          DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), columns );
96          int k = 0;
97          for ( int j = 0; j < n.columns(); j++ ) {
98              if ( droppedColumns.contains( j ) ) {
99                  continue;
100             }
101             for ( int i = 0; i < n.rows(); i++ ) {
102                 res.set( i, k, n.getQuick( i, j ) );
103             }
104             k++;
105         }
106         return res;
107     }
108 
109     /**
110      * Makes a copy
111      * 
112      * @param  list
113      * @return
114      */
115     public static DoubleMatrix1D fromList( DoubleArrayList list ) {
116         DoubleMatrix1D r = new DenseDoubleMatrix1D( list.size() );
117         for ( int i = 0; i < list.size(); i++ ) {
118             r.setQuick( i, list.getQuick( i ) );
119         }
120         return r;
121     }
122 
123     /**
124      * @param           <R>
125      * @param           <C>
126      * @param           <V>
127      * @param  matrix
128      * @param  rowIndex
129      * @param  colIndex
130      * @return
131      */
132     public static <R, C, V> V getObject( Matrix2D<R, C, V> matrix, int rowIndex, int colIndex ) {
133         if ( ObjectMatrix.class.isAssignableFrom( matrix.getClass() ) ) {
134             return ( ( ObjectMatrix<R, C, V> ) matrix ).get( rowIndex, colIndex );
135         } else if ( matrix instanceof PrimitiveMatrix<?, ?, ?> ) {
136             return ( ( PrimitiveMatrix<R, C, V> ) matrix ).getObject( rowIndex, colIndex );
137         } else {
138             throw new UnsupportedOperationException();
139         }
140     }
141 
142     /**
143      * @param           <R>
144      * @param           <C>
145      * @param           <V>
146      * @param  matrix
147      * @param  rowIndex
148      * @return
149      */
150     public static <R, C, V> V[] getRow( Matrix2D<R, C, V> matrix, int rowIndex ) {
151         if ( ObjectMatrix.class.isAssignableFrom( matrix.getClass() ) ) {
152             return ( ( ObjectMatrix<R, C, V> ) matrix ).getRow( rowIndex );
153         } else if ( matrix instanceof PrimitiveMatrix<?, ?, ?> ) {
154             return ( ( PrimitiveMatrix<R, C, V> ) matrix ).getRowObj( rowIndex );
155         } else {
156             throw new UnsupportedOperationException();
157         }
158     }
159 
160     /**
161      * @param source the source of information about missing values
162      * @param target the target where we want to convert values to missing
163      */
164     public static void maskMissing( DoubleMatrix2D source, DoubleMatrix2D target ) {
165         source.checkShape( target );
166 
167         for ( int i = 0; i < source.rows(); i++ ) {
168             for ( int j = 0; j < source.columns(); j++ ) {
169                 if ( Double.isNaN( source.getQuick( i, j ) ) ) {
170                     target.set( i, j, Double.NaN );
171                 }
172             }
173         }
174 
175     }
176 
177     public static DoubleMatrix1D multWithMissing( DoubleMatrix1D a, DoubleMatrix2D b ) {
178         return multWithMissing( a.like2D( 1, a.size() ).assign( new double[][] { a.toArray() } ), b ).viewRow( 0 );
179     }
180 
181     /**
182      * @param  a
183      * @param  b
184      * @return
185      */
186     public static DoubleMatrix1D multWithMissing( DoubleMatrix2D a, DoubleMatrix1D b ) {
187         int m = a.rows();
188         int n = a.columns();
189 
190         if ( b.size() != a.columns() ) {
191             throw new IllegalArgumentException();
192         }
193 
194         DoubleMatrix1D C = new DenseDoubleMatrix1D( m );
195         C.assign( 0.0 );
196 
197         for ( int j = 0; j < m; j++ ) {
198             double s = 0.0;
199             for ( int k = 0; k < n; k++ ) {
200                 double aval = a.getQuick( j, k );
201                 double bval = b.getQuick( k );
202                 if ( Double.isNaN( aval ) || Double.isNaN( bval ) ) {
203                     continue;
204                 }
205                 s += aval * bval;
206             }
207             C.setQuick( j, s + C.getQuick( j ) );
208         }
209 
210         return C;
211     }
212 
213     /**
214      * Multiple two matrices, tolerate missing values.
215      * 
216      * @param  a
217      * @param  b
218      * @return
219      */
220     public static DoubleMatrix2D multWithMissing( DoubleMatrix2D a, DoubleMatrix2D b ) {
221         int m = a.rows();
222         int n = a.columns();
223         int p = b.columns();
224 
225         if ( b.rows() != a.columns() ) {
226             throw new IllegalArgumentException( "Nonconformant matrices: " + b.rows() + " != " + a.columns() );
227         }
228 
229         DoubleMatrix2D C = new DenseDoubleMatrix2D( m, p );
230         C.assign( 0.0 );
231         for ( int i = 0; i < p; i++ ) {
232             for ( int j = 0; j < m; j++ ) {
233                 double s = 0.0;
234                 for ( int k = 0; k < n; k++ ) {
235                     double aval = a.getQuick( j, k );
236                     double bval = b.getQuick( k, i );
237                     if ( Double.isNaN( aval ) || Double.isNaN( bval ) ) {
238                         continue;
239                     }
240                     s += aval * bval;
241                 }
242                 C.setQuick( j, i, s + C.getQuick( j, i ) );
243             }
244         }
245         return C;
246     }
247 
248     public static List<Integer> notNearlyZeroIndices( DoubleMatrix1D d ) {
249         List<Integer> result = new ArrayList<>();
250         for ( int i = 0; i < d.size(); i++ ) {
251             if ( Math.abs( d.getQuick( i ) ) > Constants.SMALL ) result.add( i );
252         }
253         return result;
254     }
255 
256     /**
257      * @param  data
258      * @return      a copy of the data with missing or infinite values removed (might be empty!)
259      */
260     public static DoubleMatrix1D removeMissingOrInfinite(DoubleMatrix1D data ) {
261         int sizeWithoutMissingValues = sizeWithoutMissingValues( data );
262         if ( sizeWithoutMissingValues == data.size() ) return data;
263         DoubleMatrix1D r = new DenseDoubleMatrix1D( sizeWithoutMissingValues );
264         double[] elements = data.toArray();
265         int size = data.size();
266         int j = 0;
267         for ( int i = 0; i < size; i++ ) {
268             if ( Double.isNaN( elements[i] ) || Double.isInfinite( elements[i] ) ) {
269                 continue;
270             }
271             r.set( j++, elements[i] );
272         }
273         return r;
274     }
275 
276 
277     /**
278      * @param x
279      * @return a copy of x with missing values removed
280      */
281     public static final DoubleMatrix1D removeMissing(DoubleMatrix1D x) {
282         if (x.size() == 0) return x.copy();
283         BooleanArrayList ok = new BooleanArrayList(x.size());
284 
285         for (int i = 0; i < x.size(); i++) {
286             double a = x.getQuick(i);
287             ok.add(!(Double.isNaN(a)));
288         }
289 
290         return stripNonOK(x, ok);
291     }
292 
293     /**
294      * Remove values from data corresponding to missing values in reference.
295      * 
296      * @param  reference
297      * @param  data
298      * @return
299      */
300     public static DoubleMatrix1D removeMissingOrInfinite(DoubleMatrix1D reference, DoubleMatrix1D data ) {
301         if ( data.size() != reference.size() ) throw new IllegalArgumentException( "Reference and data must have same size" );
302         int sizeWithoutMissingValues = sizeWithoutMissingValues( reference );
303         if ( sizeWithoutMissingValues == reference.size() ) return data; // no missing values.
304         DoubleMatrix1D r = new DenseDoubleMatrix1D( sizeWithoutMissingValues );
305         double[] elements = data.toArray();
306         double[] refels = reference.toArray();
307         int size = data.size();
308         int j = 0;
309         for ( int i = 0; i < size; i++ ) {
310             if ( Double.isNaN( refels[i] ) || Double.isInfinite( refels[i] ) ) {
311                 continue;
312             }
313             r.set( j++, elements[i] );
314         }
315         return r;
316     }
317 
318 
319 
320 
321     public static final DoubleMatrix1D stripNegative(DoubleMatrix1D x) {
322         if (x.size() == 0) return x.copy();
323         BooleanArrayList ok = new BooleanArrayList(x.size());
324 
325         for (int i = 0; i < x.size(); i++) {
326             double a = x.getQuick(i);
327             ok.add(a >= 0.0);
328         }
329 
330         return stripNonOK(x, ok);
331     }
332 
333     public static final DoubleMatrix1D stripByCriterion(DoubleMatrix1D x, DoubleFunction<Boolean> criterion) {
334         DoubleArrayList okvals = new DoubleArrayList();
335         for (int i = 0; i < x.size(); i++) {
336             if (criterion.apply(x.get(i))) {
337                 okvals.add(x.get(i));
338             }
339         }
340         DoubleMatrix1D answer = new cern.colt.matrix.impl.DenseDoubleMatrix1D(okvals.size());
341         for (int i = 0; i < answer.size(); i++) {
342             answer.set(i, okvals.get(i));
343         }
344         return answer;
345     }
346 
347 
348     /**
349      * Compute the conjuction (logical 'and') of two boolean vectors
350      *
351      * @param a
352      * @param b
353      * @return conjunction of a and b
354      */
355     public static final BooleanArrayList conjunction(BooleanArrayList a, BooleanArrayList b) {
356         assert a.size() == b.size();
357         BooleanArrayList answer = new BooleanArrayList();
358         for (int i = 0; i < a.size(); i++) {
359             answer.add(a.get(i) && b.get(i));
360         }
361         return answer;
362     }
363 
364     /**
365      * @param x         vector to be operated on
366      * @param criterion a function that returns a boolean if a double matches the desired criteria
367      * @return booleans indicating which values in x match the criterion func
368      */
369     public static BooleanArrayList matchingCriteria(DoubleMatrix1D x, DoubleFunction<Boolean> criterion) {
370         BooleanArrayList ok = new BooleanArrayList(x.size());
371         for (int i = 0; i < x.size(); i++) {
372             double a = x.getQuick(i);
373             ok.add(criterion.apply(a));
374         }
375         return ok;
376     }
377 
378 
379     /**
380      * @param x         vector to be operated on
381      * @param criterion criterion used to test values if they should be acted on.
382      * @param action    function applied to values if they match the criterion
383      * @return copy of x in which the values meeting the criterion have been replaced with the return value of action, otherwise unchanged from the original x
384      */
385     public static DoubleMatrix1D applyToIndicesMatchingCriteria(DoubleMatrix1D
386                                                                         x, DoubleFunction<Boolean> criterion, DoubleFunction<Double> action) {
387 
388         if (x.size() == 0) return x.copy();
389         DoubleMatrix1D result = new cern.colt.matrix.impl.DenseDoubleMatrix1D(x.size());
390 
391         for (int i = 0; i < x.size(); i++) {
392             double a = x.getQuick(i);
393             if (criterion.apply(a)) {
394                 result.set(i, action.apply(a));
395             } else {
396                 result.set(i, a);
397             }
398         }
399         return result;
400     }
401 
402     /**
403      * @param x  a vector of values to be filtered
404      * @param ok a list of booleans defining which values are "ok".
405      * @return A copy of x that has the non-ok values removed.
406      */
407     public static final DoubleMatrix1D stripNonOK(DoubleMatrix1D x, BooleanArrayList ok) {
408 
409         assert ok.size() == x.size();
410 
411         DoubleArrayList okvals = new DoubleArrayList();
412         for (int i = 0; i < x.size(); i++) {
413             if (ok.get(i)) {
414                 okvals.add(x.get(i));
415             }
416         }
417         DoubleMatrix1D answer = new cern.colt.matrix.impl.DenseDoubleMatrix1D(okvals.size());
418 
419         for (int i = 0; i < answer.size(); i++) {
420             answer.set(i, okvals.get(i));
421         }
422         return answer;
423     }
424 
425     /**
426      * Perform the awkward operation of substituting certain values in a vector from values in another vector.
427      *
428      * @param x            vector to be operated on in place (it will be modified).
429      * @param toReplace    boolean indicators of same length of x, 'true' indicates a value to be replaced in x with a value from 'replacements'; 'false' will be unmodified in x.
430      * @param replacements the replacements, in order, to be substituted in x where toReplace(i) is true. This vector can be shorter than x.
431      */
432     public static void replaceValues(DoubleMatrix1D x, BooleanArrayList toReplace, DoubleMatrix1D replacements) {
433 
434         if (toReplace.size() != x.size()) {
435             throw new IllegalArgumentException("replacements and x must be the same size");
436         }
437         if (replacements.size() > x.size()) {
438             throw new IllegalArgumentException("Replacements must not outnumber the target");
439         }
440 
441         int numToReplace = 0;
442         for (int i = 0; i < toReplace.size(); i++) {
443             if (toReplace.get(i)) {
444                 numToReplace++;
445             }
446         }
447 
448         if (numToReplace != replacements.size()) {
449             throw new IllegalArgumentException("Number of replacements has to match the nubmer of true values in toReplace");
450         }
451 
452         int j = 0;
453         for (int i = 0; i < x.size(); i++) {
454             if (toReplace.get(i)) {
455                 x.set(i, replacements.get(j));
456                 j++;
457             }
458         }
459 
460     }
461 
462 
463     public static DoubleMatrix1D select( DoubleMatrix1D v, Collection<Integer> selected ) {
464         DoubleMatrix1D result = new DenseDoubleMatrix1D( selected.size() );
465         int k = 0;
466         for ( int i = 0; i < v.size(); i++ ) {
467             if ( selected.contains( i ) ) {
468                 result.set( k, v.getQuick( i ) );
469                 k++;
470             }
471         }
472         return result;
473     }
474 
475     public static DoubleMatrix2D selectColumns( DoubleMatrix2D n, Collection<Integer> selected ) {
476         int ncols = selected.size();
477         DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), ncols );
478         int k = 0;
479         for ( int j = 0; j < n.columns(); j++ ) {
480             if ( !selected.contains( j ) ) {
481                 continue;
482             }
483             for ( int i = 0; i < n.rows(); i++ ) {
484                 res.set( i, k, n.getQuick( i, j ) );
485                 i++;
486             }
487         }
488         return res;
489     }
490 
491     /**
492      * @param  n        square matrix
493      * @param  selected
494      * @return
495      */
496     public static DoubleMatrix2D selectColumnsAndRows( DoubleMatrix2D n, Collection<Integer> selected ) {
497         if ( n.rows() != n.columns() ) {
498             throw new IllegalArgumentException( "must be a square matrix" );
499         }
500 
501         if ( selected.isEmpty() ) {
502             throw new IllegalArgumentException( "must select more than one" );
503         }
504 
505         int columns = selected.size();
506 
507         if ( columns == n.columns() ) {
508             return n;
509         }
510 
511         if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
512         DoubleMatrix2D res = new DenseDoubleMatrix2D( columns, columns );
513         int k = 0;
514         for ( int j = 0; j < n.columns(); j++ ) {
515             if ( !selected.contains( j ) ) {
516                 continue;
517             }
518             int m = 0;
519             for ( int i = 0; i < n.rows(); i++ ) {
520                 if ( !selected.contains( i ) ) {
521                     continue;
522                 }
523                 res.set( m, k, n.getQuick( i, j ) );
524                 m++;
525             }
526             k++;
527         }
528         return res;
529     }
530 
531     public static DoubleMatrix2D selectRows( DoubleMatrix2D n, Collection<Integer> selected ) {
532         int nrows = selected.size();
533         DoubleMatrix2D res = new DenseDoubleMatrix2D( nrows, n.columns() );
534         for ( int j = 0; j < n.columns(); j++ ) {
535             int m = 0;
536             for ( int i = 0; i < n.rows(); i++ ) {
537                 if ( !selected.contains( i ) ) {
538                     continue;
539                 }
540                 res.set( m, j, n.getQuick( i, j ) );
541                 m++;
542             }
543         }
544         return res;
545     }
546 
547     public static int sizeWithoutMissingValues( DoubleMatrix1D list ) {
548 
549         int size = 0;
550         for ( int i = 0; i < list.size(); i++ ) {
551             double v = list.getQuick( i );
552             if ( !Double.isNaN( v ) && !Double.isInfinite( v ) ) {
553                 size++;
554             }
555         }
556         return size;
557     }
558 
559     /**
560      * Makes a copy
561      * 
562      * @param  vector
563      * @return
564      */
565     public static DoubleArrayList toList( DoubleMatrix1D vector ) {
566         return new DoubleArrayList( vector.toArray() );
567     }
568 
569 }