1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package ubic.basecode.dataStructure.matrix;
21
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleFunction;

import cern.colt.list.BooleanArrayList;
import cern.colt.list.DoubleArrayList;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import ubic.basecode.math.Constants;
33
34
35
36
37 public class MatrixUtil {
38
39
40
41
42
43 public static boolean containsNearlyZeros( DoubleMatrix1D d ) {
44 for ( int i = 0; i < d.size(); i++ ) {
45 if ( Math.abs( d.getQuick( i ) ) < Constants.SMALL ) return true;
46 }
47 return false;
48 }
49
50
51
52
53
54
55
56 public static DoubleMatrix1D diagonal( DoubleMatrix2D matrix ) {
57
58 int mindim = Math.min( matrix.rows(), matrix.columns() );
59 DoubleMatrix1D result = new DenseDoubleMatrix1D( mindim );
60 for ( int i = mindim; --i >= 0; ) {
61 result.set( i, matrix.getQuick( i, i ) );
62 }
63 return result;
64 }
65
66
67
68
69
70
71 public static DoubleMatrix2D dropColumn( DoubleMatrix2D n, int indexToDrop ) {
72 int columns = n.columns() - 1;
73 if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
74 DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), columns );
75 int k = 0;
76 for ( int j = 0; j < n.columns(); j++ ) {
77 if ( indexToDrop == j ) {
78 continue;
79 }
80 for ( int i = 0; i < n.rows(); i++ ) {
81 res.set( i, k, n.getQuick( i, j ) );
82 }
83 k++;
84 }
85 return res;
86 }
87
88
89
90
91
92 public static DoubleMatrix2D dropColumns( DoubleMatrix2D n, Collection<Integer> droppedColumns ) {
93 int columns = n.columns() - droppedColumns.size();
94 if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
95 DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), columns );
96 int k = 0;
97 for ( int j = 0; j < n.columns(); j++ ) {
98 if ( droppedColumns.contains( j ) ) {
99 continue;
100 }
101 for ( int i = 0; i < n.rows(); i++ ) {
102 res.set( i, k, n.getQuick( i, j ) );
103 }
104 k++;
105 }
106 return res;
107 }
108
109
110
111
112
113
114
115 public static DoubleMatrix1D fromList( DoubleArrayList list ) {
116 DoubleMatrix1D r = new DenseDoubleMatrix1D( list.size() );
117 for ( int i = 0; i < list.size(); i++ ) {
118 r.setQuick( i, list.getQuick( i ) );
119 }
120 return r;
121 }
122
123
124
125
126
127
128
129
130
131
132 public static <R, C, V> V getObject( Matrix2D<R, C, V> matrix, int rowIndex, int colIndex ) {
133 if ( ObjectMatrix.class.isAssignableFrom( matrix.getClass() ) ) {
134 return ( ( ObjectMatrix<R, C, V> ) matrix ).get( rowIndex, colIndex );
135 } else if ( matrix instanceof PrimitiveMatrix<?, ?, ?> ) {
136 return ( ( PrimitiveMatrix<R, C, V> ) matrix ).getObject( rowIndex, colIndex );
137 } else {
138 throw new UnsupportedOperationException();
139 }
140 }
141
142
143
144
145
146
147
148
149
150 public static <R, C, V> V[] getRow( Matrix2D<R, C, V> matrix, int rowIndex ) {
151 if ( ObjectMatrix.class.isAssignableFrom( matrix.getClass() ) ) {
152 return ( ( ObjectMatrix<R, C, V> ) matrix ).getRow( rowIndex );
153 } else if ( matrix instanceof PrimitiveMatrix<?, ?, ?> ) {
154 return ( ( PrimitiveMatrix<R, C, V> ) matrix ).getRowObj( rowIndex );
155 } else {
156 throw new UnsupportedOperationException();
157 }
158 }
159
160
161
162
163
164 public static void maskMissing( DoubleMatrix2D source, DoubleMatrix2D target ) {
165 source.checkShape( target );
166
167 for ( int i = 0; i < source.rows(); i++ ) {
168 for ( int j = 0; j < source.columns(); j++ ) {
169 if ( Double.isNaN( source.getQuick( i, j ) ) ) {
170 target.set( i, j, Double.NaN );
171 }
172 }
173 }
174
175 }
176
177 public static DoubleMatrix1D multWithMissing( DoubleMatrix1D a, DoubleMatrix2D b ) {
178 return multWithMissing( a.like2D( 1, a.size() ).assign( new double[][] { a.toArray() } ), b ).viewRow( 0 );
179 }
180
181
182
183
184
185
186 public static DoubleMatrix1D multWithMissing( DoubleMatrix2D a, DoubleMatrix1D b ) {
187 int m = a.rows();
188 int n = a.columns();
189
190 if ( b.size() != a.columns() ) {
191 throw new IllegalArgumentException();
192 }
193
194 DoubleMatrix1D C = new DenseDoubleMatrix1D( m );
195 C.assign( 0.0 );
196
197 for ( int j = 0; j < m; j++ ) {
198 double s = 0.0;
199 for ( int k = 0; k < n; k++ ) {
200 double aval = a.getQuick( j, k );
201 double bval = b.getQuick( k );
202 if ( Double.isNaN( aval ) || Double.isNaN( bval ) ) {
203 continue;
204 }
205 s += aval * bval;
206 }
207 C.setQuick( j, s + C.getQuick( j ) );
208 }
209
210 return C;
211 }
212
213
214
215
216
217
218
219
220 public static DoubleMatrix2D multWithMissing( DoubleMatrix2D a, DoubleMatrix2D b ) {
221 int m = a.rows();
222 int n = a.columns();
223 int p = b.columns();
224
225 if ( b.rows() != a.columns() ) {
226 throw new IllegalArgumentException( "Nonconformant matrices: " + b.rows() + " != " + a.columns() );
227 }
228
229 DoubleMatrix2D C = new DenseDoubleMatrix2D( m, p );
230 C.assign( 0.0 );
231 for ( int i = 0; i < p; i++ ) {
232 for ( int j = 0; j < m; j++ ) {
233 double s = 0.0;
234 for ( int k = 0; k < n; k++ ) {
235 double aval = a.getQuick( j, k );
236 double bval = b.getQuick( k, i );
237 if ( Double.isNaN( aval ) || Double.isNaN( bval ) ) {
238 continue;
239 }
240 s += aval * bval;
241 }
242 C.setQuick( j, i, s + C.getQuick( j, i ) );
243 }
244 }
245 return C;
246 }
247
248 public static List<Integer> notNearlyZeroIndices( DoubleMatrix1D d ) {
249 List<Integer> result = new ArrayList<>();
250 for ( int i = 0; i < d.size(); i++ ) {
251 if ( Math.abs( d.getQuick( i ) ) > Constants.SMALL ) result.add( i );
252 }
253 return result;
254 }
255
256
257
258
259
260 public static DoubleMatrix1D removeMissingOrInfinite(DoubleMatrix1D data ) {
261 int sizeWithoutMissingValues = sizeWithoutMissingValues( data );
262 if ( sizeWithoutMissingValues == data.size() ) return data;
263 DoubleMatrix1D r = new DenseDoubleMatrix1D( sizeWithoutMissingValues );
264 double[] elements = data.toArray();
265 int size = data.size();
266 int j = 0;
267 for ( int i = 0; i < size; i++ ) {
268 if ( Double.isNaN( elements[i] ) || Double.isInfinite( elements[i] ) ) {
269 continue;
270 }
271 r.set( j++, elements[i] );
272 }
273 return r;
274 }
275
276
277
278
279
280
281 public static final DoubleMatrix1D removeMissing(DoubleMatrix1D x) {
282 if (x.size() == 0) return x.copy();
283 BooleanArrayList ok = new BooleanArrayList(x.size());
284
285 for (int i = 0; i < x.size(); i++) {
286 double a = x.getQuick(i);
287 ok.add(!(Double.isNaN(a)));
288 }
289
290 return stripNonOK(x, ok);
291 }
292
293
294
295
296
297
298
299
300 public static DoubleMatrix1D removeMissingOrInfinite(DoubleMatrix1D reference, DoubleMatrix1D data ) {
301 if ( data.size() != reference.size() ) throw new IllegalArgumentException( "Reference and data must have same size" );
302 int sizeWithoutMissingValues = sizeWithoutMissingValues( reference );
303 if ( sizeWithoutMissingValues == reference.size() ) return data;
304 DoubleMatrix1D r = new DenseDoubleMatrix1D( sizeWithoutMissingValues );
305 double[] elements = data.toArray();
306 double[] refels = reference.toArray();
307 int size = data.size();
308 int j = 0;
309 for ( int i = 0; i < size; i++ ) {
310 if ( Double.isNaN( refels[i] ) || Double.isInfinite( refels[i] ) ) {
311 continue;
312 }
313 r.set( j++, elements[i] );
314 }
315 return r;
316 }
317
318
319
320
321 public static final DoubleMatrix1D stripNegative(DoubleMatrix1D x) {
322 if (x.size() == 0) return x.copy();
323 BooleanArrayList ok = new BooleanArrayList(x.size());
324
325 for (int i = 0; i < x.size(); i++) {
326 double a = x.getQuick(i);
327 ok.add(a >= 0.0);
328 }
329
330 return stripNonOK(x, ok);
331 }
332
333 public static final DoubleMatrix1D stripByCriterion(DoubleMatrix1D x, DoubleFunction<Boolean> criterion) {
334 DoubleArrayList okvals = new DoubleArrayList();
335 for (int i = 0; i < x.size(); i++) {
336 if (criterion.apply(x.get(i))) {
337 okvals.add(x.get(i));
338 }
339 }
340 DoubleMatrix1D answer = new cern.colt.matrix.impl.DenseDoubleMatrix1D(okvals.size());
341 for (int i = 0; i < answer.size(); i++) {
342 answer.set(i, okvals.get(i));
343 }
344 return answer;
345 }
346
347
348
349
350
351
352
353
354
355 public static final BooleanArrayList conjunction(BooleanArrayList a, BooleanArrayList b) {
356 assert a.size() == b.size();
357 BooleanArrayList answer = new BooleanArrayList();
358 for (int i = 0; i < a.size(); i++) {
359 answer.add(a.get(i) && b.get(i));
360 }
361 return answer;
362 }
363
364
365
366
367
368
369 public static BooleanArrayList matchingCriteria(DoubleMatrix1D x, DoubleFunction<Boolean> criterion) {
370 BooleanArrayList ok = new BooleanArrayList(x.size());
371 for (int i = 0; i < x.size(); i++) {
372 double a = x.getQuick(i);
373 ok.add(criterion.apply(a));
374 }
375 return ok;
376 }
377
378
379
380
381
382
383
384
385 public static DoubleMatrix1D applyToIndicesMatchingCriteria(DoubleMatrix1D
386 x, DoubleFunction<Boolean> criterion, DoubleFunction<Double> action) {
387
388 if (x.size() == 0) return x.copy();
389 DoubleMatrix1D result = new cern.colt.matrix.impl.DenseDoubleMatrix1D(x.size());
390
391 for (int i = 0; i < x.size(); i++) {
392 double a = x.getQuick(i);
393 if (criterion.apply(a)) {
394 result.set(i, action.apply(a));
395 } else {
396 result.set(i, a);
397 }
398 }
399 return result;
400 }
401
402
403
404
405
406
407 public static final DoubleMatrix1D stripNonOK(DoubleMatrix1D x, BooleanArrayList ok) {
408
409 assert ok.size() == x.size();
410
411 DoubleArrayList okvals = new DoubleArrayList();
412 for (int i = 0; i < x.size(); i++) {
413 if (ok.get(i)) {
414 okvals.add(x.get(i));
415 }
416 }
417 DoubleMatrix1D answer = new cern.colt.matrix.impl.DenseDoubleMatrix1D(okvals.size());
418
419 for (int i = 0; i < answer.size(); i++) {
420 answer.set(i, okvals.get(i));
421 }
422 return answer;
423 }
424
425
426
427
428
429
430
431
432 public static void replaceValues(DoubleMatrix1D x, BooleanArrayList toReplace, DoubleMatrix1D replacements) {
433
434 if (toReplace.size() != x.size()) {
435 throw new IllegalArgumentException("replacements and x must be the same size");
436 }
437 if (replacements.size() > x.size()) {
438 throw new IllegalArgumentException("Replacements must not outnumber the target");
439 }
440
441 int numToReplace = 0;
442 for (int i = 0; i < toReplace.size(); i++) {
443 if (toReplace.get(i)) {
444 numToReplace++;
445 }
446 }
447
448 if (numToReplace != replacements.size()) {
449 throw new IllegalArgumentException("Number of replacements has to match the nubmer of true values in toReplace");
450 }
451
452 int j = 0;
453 for (int i = 0; i < x.size(); i++) {
454 if (toReplace.get(i)) {
455 x.set(i, replacements.get(j));
456 j++;
457 }
458 }
459
460 }
461
462
463 public static DoubleMatrix1D select( DoubleMatrix1D v, Collection<Integer> selected ) {
464 DoubleMatrix1D result = new DenseDoubleMatrix1D( selected.size() );
465 int k = 0;
466 for ( int i = 0; i < v.size(); i++ ) {
467 if ( selected.contains( i ) ) {
468 result.set( k, v.getQuick( i ) );
469 k++;
470 }
471 }
472 return result;
473 }
474
475 public static DoubleMatrix2D selectColumns( DoubleMatrix2D n, Collection<Integer> selected ) {
476 int ncols = selected.size();
477 DoubleMatrix2D res = new DenseDoubleMatrix2D( n.rows(), ncols );
478 int k = 0;
479 for ( int j = 0; j < n.columns(); j++ ) {
480 if ( !selected.contains( j ) ) {
481 continue;
482 }
483 for ( int i = 0; i < n.rows(); i++ ) {
484 res.set( i, k, n.getQuick( i, j ) );
485 i++;
486 }
487 }
488 return res;
489 }
490
491
492
493
494
495
496 public static DoubleMatrix2D selectColumnsAndRows( DoubleMatrix2D n, Collection<Integer> selected ) {
497 if ( n.rows() != n.columns() ) {
498 throw new IllegalArgumentException( "must be a square matrix" );
499 }
500
501 if ( selected.isEmpty() ) {
502 throw new IllegalArgumentException( "must select more than one" );
503 }
504
505 int columns = selected.size();
506
507 if ( columns == n.columns() ) {
508 return n;
509 }
510
511 if ( columns < 0 ) throw new IllegalArgumentException( "Must leave some columns" );
512 DoubleMatrix2D res = new DenseDoubleMatrix2D( columns, columns );
513 int k = 0;
514 for ( int j = 0; j < n.columns(); j++ ) {
515 if ( !selected.contains( j ) ) {
516 continue;
517 }
518 int m = 0;
519 for ( int i = 0; i < n.rows(); i++ ) {
520 if ( !selected.contains( i ) ) {
521 continue;
522 }
523 res.set( m, k, n.getQuick( i, j ) );
524 m++;
525 }
526 k++;
527 }
528 return res;
529 }
530
531 public static DoubleMatrix2D selectRows( DoubleMatrix2D n, Collection<Integer> selected ) {
532 int nrows = selected.size();
533 DoubleMatrix2D res = new DenseDoubleMatrix2D( nrows, n.columns() );
534 for ( int j = 0; j < n.columns(); j++ ) {
535 int m = 0;
536 for ( int i = 0; i < n.rows(); i++ ) {
537 if ( !selected.contains( i ) ) {
538 continue;
539 }
540 res.set( m, j, n.getQuick( i, j ) );
541 m++;
542 }
543 }
544 return res;
545 }
546
547 public static int sizeWithoutMissingValues( DoubleMatrix1D list ) {
548
549 int size = 0;
550 for ( int i = 0; i < list.size(); i++ ) {
551 double v = list.getQuick( i );
552 if ( !Double.isNaN( v ) && !Double.isInfinite( v ) ) {
553 size++;
554 }
555 }
556 return size;
557 }
558
559
560
561
562
563
564
565 public static DoubleArrayList toList( DoubleMatrix1D vector ) {
566 return new DoubleArrayList( vector.toArray() );
567 }
568
569 }