1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ubic.basecode.math.linearmodels;
16
17 import java.util.ArrayList;
18 import java.util.Arrays;
19 import java.util.Collection;
20 import java.util.Collections;
21 import java.util.HashSet;
22 import java.util.LinkedHashMap;
23 import java.util.LinkedHashSet;
24 import java.util.List;
25 import java.util.Map;
26 import java.util.Set;
27
28 import org.apache.commons.lang3.StringUtils;
29 import org.slf4j.Logger;
30 import org.slf4j.LoggerFactory;
31
32 import ubic.basecode.dataStructure.matrix.DenseDoubleMatrix;
33 import ubic.basecode.dataStructure.matrix.DoubleMatrix;
34 import ubic.basecode.dataStructure.matrix.ObjectMatrix;
35 import ubic.basecode.dataStructure.matrix.StringMatrix;
36 import cern.colt.matrix.DoubleMatrix2D;
37 import cern.colt.matrix.impl.DenseDoubleMatrix2D;
38
39
40
41
42
43
44
45
46
47
48
49
50
51 public class DesignMatrix {
52
53 private static Logger log = LoggerFactory.getLogger( DesignMatrix.class );
54
55
56
57 private List<Integer> assign = new ArrayList<>();
58
59
60
61
62 private final Set<String> droppedFactors = new HashSet<>();
63
64 private boolean hasIntercept = false;
65
66 private final Set<String[]> interactions = new LinkedHashSet<>();
67
68
69
70
71
72
73 public Collection<String[]> getInteractionTerms() {
74 return interactions;
75 }
76
77
78
79
80 private final Map<String, List<String>> levelsForFactors = new LinkedHashMap<>();
81
82 private DoubleMatrix<String, String> matrix;
83
84
85
86
87 private Map<String, List<Integer>> terms = new LinkedHashMap<>();
88
89
90
91
92 private final Map<String, List<Object>> valuesForFactors = new LinkedHashMap<>();
93
94
95
96
97
98 private boolean strict = true;
99
100
101
102
103
/**
 * Build a design from a single factor, without an intercept.
 *
 * @param factor     the factor values, one per sample
 * @param start      1-based position of the first level that receives its own column; 2 skips the first level
 * @param factorName name used for the term and as the prefix of the column names
 */
public DesignMatrix( Object[] factor, int start, String factorName ) {
    List<Object> values = Arrays.asList( factor );
    this.matrix = buildDesign( 1, values, null, start, factorName );
}
107
108
109
110
/**
 * Build a design that includes an intercept.
 *
 * @param sampleInfo one column per factor, one row per sample
 */
public DesignMatrix( ObjectMatrix<String, String, ? extends Object> sampleInfo ) {
    // Delegates to the two-argument form with the intercept switched on.
    this( sampleInfo, true );
}
114
115
116
117
118
119 public DesignMatrix( ObjectMatrix<String, String, ?> sampleInfo, boolean intercept ) {
120 matrix = this.designMatrix( sampleInfo, intercept );
121 this.hasIntercept = intercept;
122 if ( sampleInfo.getRowNames().size() == matrix.rows() ) matrix.setRowNames( sampleInfo.getRowNames() );
123
124 }
125
126
127
128
129
130
131
132 public DoubleMatrix2D makeContrasts() {
133
134
135
136
137
138
139
140
141
142
143 throw new RuntimeException();
144 }
145
/**
 * Build a design that includes an intercept from a matrix of string-valued factors.
 *
 * @param sampleInfo one column per factor, one row per sample
 */
public DesignMatrix( StringMatrix<String, String> sampleInfo ) {
    // Same as the general form, with the intercept switched on.
    this( sampleInfo, true );
}
149
150
151
152
153
154
155
/**
 * Build a design and choose how later interaction columns are screened.
 *
 * @param design    one column per factor, one row per sample
 * @param intercept whether an intercept column should be placed first
 * @param strict    if true, interaction columns that have fewer than two non-zero entries or that duplicate an
 *                  earlier column are dropped when interactions are added
 */
public DesignMatrix( ObjectMatrix<String, String, Object> design, boolean intercept, boolean strict ) {
    this( design, intercept );
    // Only consulted when interactions are added, so setting it after the build is sufficient.
    this.strict = strict;
}
160
161
162
163
164
165
166 public void add( ObjectMatrix<String, String, Object> sampleInfo ) {
167 this.matrix = this.designMatrix( sampleInfo, this.matrix );
168 }
169
170
171
172
173
174 public void addInteraction() {
175 if ( this.terms.size() != 2 && !hasIntercept ) {
176 throw new IllegalArgumentException( "You must specify which two terms" );
177 }
178
179 if ( this.terms.size() == 2 && hasIntercept || this.terms.size() < 2 ) {
180 throw new IllegalArgumentException( "You need at least two terms" );
181 }
182
183 if ( this.terms.size() > 3 ) {
184 throw new IllegalArgumentException( "You must specify which two terms, there are " + this.terms.size()
185 + " terms: " + StringUtils.join( this.terms.keySet(), "," ) );
186 }
187
188 List<String> iterms = new ArrayList<>();
189 for ( String t : terms.keySet() ) {
190 if ( t.equals( LinearModelSummary.INTERCEPT_COEFFICIENT_NAME ) ) {
191 continue;
192 }
193 iterms.add( t );
194 }
195
196 this.addInteraction( iterms.toArray( new String[] {} ) );
197 }
198
199
200
201
202
203
204
205 public void addInteraction( String... interactionTerms ) {
206
207
208
209
210 for ( String t1 : interactionTerms ) {
211 if ( !this.getLevelsForFactors().containsKey( t1 ) ) {
212 log.warn( "Can't add interaction involving a non-existent or unused terms: " + t1 );
213 return;
214 }
215 }
216
217
218
219
220
221
222 Collection<String> doneTerms = new HashSet<>();
223
224
225 int interactionIndex = terms.size();
226
227 String termName = StringUtils.join( interactionTerms, ":" );
228 Set<String> usedInteractionTerms = new HashSet<>();
229 Arrays.sort( interactionTerms );
230 for ( String t1 : interactionTerms ) {
231
232 if ( doneTerms.contains( t1 ) ) continue;
233 List<Integer> cols1 = terms.get( t1 );
234
235 for ( int i = 0; i < cols1.size(); i++ ) {
236 double[] col1i = this.matrix.getColumn( cols1.get( i ) );
237
238 assert col1i.length > 0;
239
240 for ( String t2 : interactionTerms ) {
241 if ( t1.equals( t2 ) ) continue;
242 doneTerms.add( t2 );
243 List<Integer> cols2 = terms.get( t2 );
244
245 assert cols2 != null;
246
247 for ( int j = 0; j < cols2.size(); j++ ) {
248 double[] col2i = this.matrix.getColumn( cols2.get( j ) );
249
250 Double[] prod = new Double[col1i.length];
251
252 this.matrix = this.copyWithSpace( this.matrix, this.matrix.columns() + 1 );
253 String columnName = null;
254 int numValid = 0;
255 for ( int k = 0; k < col1i.length; k++ ) {
256 prod[k] = col1i[k] * col2i[k];
257 if ( prod[k] != 0 ) numValid++;
258
259
260 if ( prod[k] != 0 && StringUtils.isBlank( columnName ) ) {
261 String if1 = this.valuesForFactors.get( t1 ).get( k ).toString();
262 String if2 = this.valuesForFactors.get( t2 ).get( k ).toString();
263 columnName = t1 + if1 + ":" + t2 + if2;
264 }
265
266 matrix.set( k, this.matrix.columns() - 1, prod[k] );
267 }
268
269 if ( numValid < 2 && strict ) {
270
271
272
273 log.info( "Interaction term " + termName + " won't be estimable, dropping" );
274 matrix = matrix.getColRange( 0, this.matrix.columns() - 2 );
275 continue;
276 }
277
278 boolean redundant = checkForRedundancy( this.matrix, this.matrix.columns() - 2 );
279 if ( redundant && strict ) {
280
281
282
283 log.info( "Interaction term " + termName + " is redundant with another column, dropping" );
284 matrix = matrix.getColRange( 0, this.matrix.columns() - 2 );
285 continue;
286 }
287
288 usedInteractionTerms.add( t1 );
289 usedInteractionTerms.add( t2 );
290
291 matrix.addColumnName( columnName );
292 if ( !this.terms.containsKey( termName ) ) {
293 this.terms.put( termName, new ArrayList<Integer>() );
294 }
295 terms.get( termName ).add( this.matrix.columns() - 1 );
296 assign.add( interactionIndex );
297 }
298 }
299 }
300 }
301
302 if ( !usedInteractionTerms.isEmpty() ) {
303 this.interactions.add( usedInteractionTerms.toArray( new String[] {} ) );
304 }
305
306 }
307
308
309
310
311 public List<Integer> getAssign() {
312 return assign;
313 }
314
315 public String getBaseline( String factorName ) {
316 return this.levelsForFactors.get( factorName ).toString();
317 }
318
319 public DoubleMatrix2D getDoubleMatrix() {
320 return new DenseDoubleMatrix2D( matrix.asDoubles() );
321 }
322
323 public Map<String, List<String>> getLevelsForFactors() {
324 return levelsForFactors;
325 }
326
327 public DoubleMatrix<String, String> getMatrix() {
328 return matrix;
329 }
330
331 public List<String> getTerms() {
332 List<String> result = new ArrayList<>();
333 result.addAll( terms.keySet() );
334 return result;
335 }
336
337 public Map<String, List<Object>> getValuesForFactors() {
338 return valuesForFactors;
339 }
340
341 public boolean hasIntercept() {
342 return this.hasIntercept;
343 }
344
345 public boolean isHasIntercept() {
346 return hasIntercept;
347 }
348
349
350
351
352
353 public void setBaseline( String factorName, String baselineFactorValue ) {
354 if ( !this.levelsForFactors.containsKey( factorName ) ) {
355 throw new IllegalArgumentException( "No factor known by name " + factorName + ", choices are: "
356 + StringUtils.join( this.levelsForFactors.keySet(), "," ) );
357 }
358
359 if ( this.droppedFactors.contains( factorName ) ) {
360 log.warn( "Can't set baseline for a dropped factor, skipping" );
361 return;
362 }
363
364 List<String> oldValues = this.levelsForFactors.get( factorName );
365 int index = oldValues.indexOf( baselineFactorValue );
366 if ( index < 0 ) {
367 throw new IllegalArgumentException( baselineFactorValue + " is not a level of the factor " + factorName );
368 }
369
370 if ( index == 0 ) return;
371
372
373
374
375 List<String> releveled = new ArrayList<>();
376 releveled.add( oldValues.get( index ) );
377 for ( int i = 0; i < oldValues.size(); i++ ) {
378 if ( i == index ) continue;
379 releveled.add( oldValues.get( i ) );
380 }
381 this.levelsForFactors.put( factorName, releveled );
382
383
384
385
386 this.rebuild();
387 }
388
389 @Override
390 public String toString() {
391 return this.matrix.toString();
392 }
393
394
395
396
397 protected void rebuild() {
398 this.matrix = null;
399 this.assign.clear();
400 this.terms.clear();
401
402 if ( this.hasIntercept ) {
403 int nrows = valuesForFactors.values().iterator().next().size();
404 matrix = addIntercept( nrows );
405 }
406
407 int i = 0;
408 for ( String factorName : valuesForFactors.keySet() ) {
409 List<Object> factorValues = valuesForFactors.get( factorName );
410 this.valuesForFactors.put( factorName, factorValues );
411
412 if ( factorValues.get( 0 ) instanceof String && !this.levelsForFactors.containsKey( factorName ) ) {
413 this.levels( factorName, factorValues.toArray( new String[] {} ) );
414 }
415
416 matrix = buildDesign( i + 1, factorValues, matrix, 2, factorName );
417
418 i++;
419 }
420
421 if ( !this.interactions.isEmpty() ) {
422 List<String[]> redoInteractionTerms = new ArrayList<>();
423 for ( String[] interactionTerms : interactions ) {
424 redoInteractionTerms.add( interactionTerms );
425 }
426 this.interactions.clear();
427 for ( String[] t : redoInteractionTerms ) {
428 this.addInteraction( t );
429 }
430 }
431 }
432
433
434
435
436
437
438 private DoubleMatrix<String, String> addContinuousCovariate( List<?> vec,
439 DoubleMatrix<String, String> inputDesign ) {
440 DoubleMatrix<String, String> tmp;
441
442
443
444 log.debug( "Treating factor as continuous covariate" );
445 if ( inputDesign != null ) {
446
447
448
449 assert vec.size() == inputDesign.rows();
450 int numberofColumns = inputDesign.columns() + 1;
451 tmp = copyWithSpace( inputDesign, numberofColumns );
452 } else {
453 tmp = new DenseDoubleMatrix<>( vec.size(), 1 );
454 tmp.assign( 0.0 );
455 }
456 int startcol = 0;
457 if ( inputDesign != null ) {
458 startcol = inputDesign.columns();
459 }
460 for ( int i = startcol; i < tmp.columns(); i++ ) {
461 for ( int j = 0; j < tmp.rows(); j++ ) {
462 tmp.set( j, i, ( Double ) vec.get( j ) );
463 }
464 }
465 return tmp;
466 }
467
468
469
470
471
472 private DoubleMatrix<String, String> addIntercept( int rows ) {
473 DoubleMatrix<String, String> tmp;
474 tmp = new DenseDoubleMatrix<>( rows, 1 );
475 tmp.addColumnName( LinearModelSummary.INTERCEPT_COEFFICIENT_NAME );
476 tmp.assign( 1.0 );
477 this.assign.add( 0 );
478 this.terms.put( LinearModelSummary.INTERCEPT_COEFFICIENT_NAME, new ArrayList<Integer>() );
479 this.terms.get( LinearModelSummary.INTERCEPT_COEFFICIENT_NAME ).add( 0 );
480 return tmp;
481 }
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496 @SuppressWarnings("unchecked")
497 private DoubleMatrix<String, String> buildDesign( int columnNum, List<?> factorValues,
498 DoubleMatrix<String, String> inputDesign, final int start, String factorName ) {
499
500 int startUsed = start;
501
502 if ( !terms.containsKey( factorName ) ) {
503 terms.put( factorName, new ArrayList<Integer>() );
504 }
505 DoubleMatrix<String, String> tmp = null;
506 if ( factorValues.get( 0 ) instanceof Double ) {
507 tmp = addContinuousCovariate( factorValues, inputDesign );
508 this.assign.add( columnNum );
509 terms.get( factorName ).add( columnNum );
510 tmp.addColumnName( factorName );
511 } else {
512
513
514
515 List<String> levels;
516 if ( this.levelsForFactors.containsKey( factorName ) ) {
517 levels = this.levelsForFactors.get( factorName );
518 } else {
519 levels = levels( factorName, ( List<String> ) factorValues );
520 }
521
522
523
524
525
526
527
528
529
530
531
532 tmp = inputDesign;
533
534 List<String> levelList = new ArrayList<>();
535 levelList.addAll( levels );
536
537 int startcol = 0;
538 if ( tmp != null ) {
539 startcol = inputDesign.columns();
540 }
541
542 int currentColumn = startcol;
543 int maxColumn = levels.size() + startcol - ( startUsed - 1 );
544 Collection<String> usedLevels = new HashSet<>();
545 for ( int i = startcol; i < maxColumn; i++ ) {
546
547
548
549 int currentLevelIndex = i - startcol + ( startUsed - 1 );
550
551 if ( currentLevelIndex >= levelList.size() ) {
552 if ( startUsed > 1 ) {
553
554 currentLevelIndex = 0;
555 }
556 }
557
558 String level = levelList.get( currentLevelIndex );
559 log.debug( "Adding column for Level=" + level + " at index " + currentLevelIndex );
560
561
562 if ( tmp != null ) {
563 tmp = copyWithSpace( tmp, tmp.columns() + 1 );
564 } else {
565 tmp = new DenseDoubleMatrix<>( factorValues.size(), 1 );
566 tmp.assign( 0.0 );
567 }
568
569 String contrastingValue = "";
570 assert tmp != null;
571 for ( int j = 0; j < tmp.rows(); j++ ) {
572 Object fv = factorValues.get(j);
573
574 if (fv == null) {
575
576
577 throw new IllegalArgumentException("Null value for factor " + factorName + " at row " + j);
578 }
579
580 boolean isBaseline = !fv.equals( level );
581 if ( !isBaseline ) {
582 contrastingValue = ( String ) fv;
583 }
584 tmp.set( j, currentColumn, isBaseline ? 0.0 : 1.0 );
585 }
586
587
588
589
590
591
592
593
594
595
596
597 currentColumn++;
598
599 if ( StringUtils.isBlank( contrastingValue ) ) {
600 contrastingValue = "_" + i;
601 }
602 tmp.setColumnName( factorName + contrastingValue, currentColumn );
603 this.assign.add( columnNum );
604 terms.get( factorName ).add( i );
605 usedLevels.add( level );
606 }
607
608
609
610
611
612 }
613
614 return tmp;
615 }
616
617
618
619
620
621
622
623
624 private boolean checkForRedundancy( DoubleMatrix<String, String> tmp, int column ) {
625 for ( int p = 0; p < column; p++ ) {
626
627 boolean foundRedundant = true;
628 for ( int v = 0; v < tmp.rows(); v++ ) {
629 if ( tmp.get( v, column ) != tmp.get( v, p ) ) {
630 foundRedundant = false;
631 break;
632 }
633 }
634
635 if ( foundRedundant ) {
636 return true;
637 }
638 }
639 return false;
640 }
641
642
643
644
645
646
647
648
649 private DoubleMatrix<String, String> copyWithSpace( DoubleMatrix<String, String> inputDesign,
650 int numberofColumns ) {
651 DoubleMatrix<String, String> tmp;
652 tmp = new DenseDoubleMatrix<>( inputDesign.rows(), numberofColumns );
653 tmp.assign( 0.0 );
654
655 for ( int i = 0; i < inputDesign.rows(); i++ ) {
656 for ( int j = 0; j < inputDesign.columns(); j++ ) {
657 String colName = inputDesign.getColName( j );
658 if ( i == 0 && colName != null ) {
659 tmp.setColumnName( colName, j );
660 }
661 tmp.set( i, j, inputDesign.get( i, j ) );
662 }
663 }
664
665 if ( !inputDesign.getRowNames().isEmpty() ) tmp.setRowNames( inputDesign.getRowNames() );
666 return tmp;
667 }
668
669
670
671
672
673
674
675
676 private DoubleMatrix<String, String> designMatrix( ObjectMatrix<String, String, ?> sampleInfo, boolean intercept ) {
677 DoubleMatrix<String, String> tmp = null;
678 if ( intercept ) {
679 int rows = sampleInfo.rows();
680 tmp = addIntercept( rows );
681 }
682 return designMatrix( sampleInfo, tmp );
683 }
684
685
686
687
688
689
690 private DoubleMatrix<String, String> designMatrix( ObjectMatrix<String, String, ?> sampleInfo,
691 DoubleMatrix<String, String> design ) {
692 for ( int i = 0; i < sampleInfo.columns(); i++ ) {
693 Object[] factorValuesAr = sampleInfo.getColumn( i );
694 List<Object> factorValues = Arrays.asList( factorValuesAr );
695 design = buildDesign( i + 1, factorValues, design, 2, sampleInfo.getColName( i ) );
696 this.valuesForFactors.put( sampleInfo.getColName( i ), factorValues );
697 }
698 return design;
699 }
700
701
702
703
704
705 private List<String> levels( String factorName, List<String> vec ) {
706 return this.levels( factorName, vec.toArray( new String[] {} ) );
707 }
708
709
710
711
712
713
714 private List<String> levels( String factorName, String[] vec ) {
715 Set<String> flevs = new LinkedHashSet<>();
716 for ( String v : vec ) {
717 flevs.add( v );
718 }
719 List<String> result = new ArrayList<>();
720 for ( String fl : flevs ) {
721 result.add( fl );
722 }
723 this.levelsForFactors.put( factorName, result );
724 return result;
725 }
726
727 }