1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ubic.basecode.math.linearmodels;
16
17 import java.io.File;
18 import java.io.IOException;
19 import java.io.OutputStream;
20 import java.io.PrintStream;
21 import java.util.ArrayList;
22 import java.util.Arrays;
23 import java.util.Collection;
24 import java.util.HashMap;
25 import java.util.LinkedHashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Set;
29 import java.util.TreeMap;
30 import java.util.TreeSet;
31
32 import org.apache.commons.lang3.ArrayUtils;
33 import org.apache.commons.lang3.StringUtils;
34 import org.apache.commons.lang3.time.StopWatch;
35 import org.apache.commons.math3.distribution.FDistribution;
36 import org.apache.commons.math3.distribution.TDistribution;
37 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
38 import org.slf4j.Logger;
39 import org.slf4j.LoggerFactory;
40
41 import cern.colt.bitvector.BitVector;
42 import cern.colt.list.DoubleArrayList;
43 import cern.colt.matrix.DoubleMatrix1D;
44 import cern.colt.matrix.DoubleMatrix2D;
45 import cern.colt.matrix.impl.DenseDoubleMatrix1D;
46 import cern.colt.matrix.impl.DenseDoubleMatrix2D;
47 import cern.colt.matrix.linalg.Algebra;
48 import cern.jet.math.Functions;
49 import cern.jet.stat.Descriptive;
50 import ubic.basecode.dataStructure.matrix.DoubleMatrix;
51 import ubic.basecode.dataStructure.matrix.DoubleMatrixFactory;
52 import ubic.basecode.dataStructure.matrix.MatrixUtil;
53 import ubic.basecode.dataStructure.matrix.ObjectMatrix;
54 import ubic.basecode.math.Constants;
55 import ubic.basecode.math.linalg.QRDecomposition;
56 import ubic.basecode.util.r.type.AnovaEffect;
57
58
59
60
61
62
63
64
65
66
67 public class LeastSquaresFit {
68
69 private static Logger log = LoggerFactory.getLogger(LeastSquaresFit.class);
70
71
72
73
74 boolean hasBeenShrunken = false;
75
76
77
78
79 private DoubleMatrix2D A;
80
81
82
83
84
85 private List<Integer> assign = new ArrayList<>();
86
87
88
89
90 private List<List<Integer>> assigns = new ArrayList<>();
91
92
93
94
95 private DoubleMatrix2D b;
96
97
98
99
100 private DoubleMatrix2D coefficients = null;
101
102
103
104
105 private DesignMatrix designMatrix;
106
107
108
109
110 private double dfPrior = 0;
111
112
113
114
115 private DoubleMatrix2D fitted;
116
117
118
119
120 private boolean hasIntercept = true;
121
122
123
124
125 private boolean hasMissing = false;
126
127
128
129
130 private QRDecomposition qr = null;
131
132
133
134
135
136
137 private Map<BitVector, QRDecomposition> qrs = new HashMap<>();
138
139
140
141
142
143 private Map<Integer, QRDecomposition> qrsForWeighted = new HashMap<>();
144
145 private int residualDof = -1;
146
147
148
149
150 private List<Integer> residualDofs = new ArrayList<>();
151
152
153
154
155 private DoubleMatrix2D residuals = null;
156
157
158
159
160 private List<String> rowNames;
161
162
163
164
165 private Map<Integer, DoubleMatrix1D> stdevUnscaled = new TreeMap<>();
166
167
168
169
170 private List<String> terms;
171
172
173
174
175 private Map<Integer, BitVector> valuesPresentMap = new HashMap<>();
176
177
178
179
180 private DoubleMatrix1D varPost = null;
181
182
183
184
185 private Double varPrior = null;
186
187
188
189
190 private DoubleMatrix2D weights = null;
191
192
193
194
195
196
197
198 public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix<String, String> data) {
199 this.designMatrix = designMatrix;
200 this.A = designMatrix.getDoubleMatrix();
201 this.assign = designMatrix.getAssign();
202 this.terms = designMatrix.getTerms();
203
204 this.rowNames = data.getRowNames();
205 this.b = new DenseDoubleMatrix2D(data.asArray());
206 boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
207 this.hasIntercept = designMatrix.hasIntercept();
208 assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
209 fit();
210 }
211
212
213
214
215
216
217
218
219 public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix<String, String> data,
220 final DoubleMatrix2D weights) {
221 this.designMatrix = designMatrix;
222 DoubleMatrix2D X = designMatrix.getDoubleMatrix();
223 this.assign = designMatrix.getAssign();
224 this.terms = designMatrix.getTerms();
225 this.A = X;
226 this.rowNames = data.getRowNames();
227 this.b = new DenseDoubleMatrix2D(data.asArray());
228 boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
229 this.hasIntercept = designMatrix.hasIntercept();
230 assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
231 this.weights = weights;
232 fit();
233 }
234
235
236
237
238
239
240
241
242 public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix2D b, final DoubleMatrix2D weights) {
243
244 this.designMatrix = designMatrix;
245 DoubleMatrix2D X = designMatrix.getDoubleMatrix();
246 this.assign = designMatrix.getAssign();
247 this.terms = designMatrix.getTerms();
248 this.A = X;
249 this.b = b;
250 boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
251 this.hasIntercept = designMatrix.hasIntercept();
252 assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
253
254 this.weights = weights;
255
256 fit();
257
258 }
259
260
261
262
263
264
265
/**
 * Simple linear regression of one vector on another: fits {@code vectorB ~ 1 + vectorA}.
 *
 * @param vectorA the predictor values
 * @param vectorB the response values; must have the same length as vectorA
 * @throws IllegalArgumentException if either vector is null or the lengths differ
 */
public LeastSquaresFit(DoubleMatrix1D vectorA, DoubleMatrix1D vectorB) {
    // Validate explicitly: assertions are normally disabled, and a length mismatch would otherwise either
    // leave part of the response as zeros or fail later with an unrelated index error.
    if (vectorA == null || vectorB == null) {
        throw new IllegalArgumentException("Vectors must not be null");
    }
    if (vectorA.size() != vectorB.size()) {
        throw new IllegalArgumentException(
                "Vectors must be the same length: " + vectorA.size() + " != " + vectorB.size());
    }

    int n = vectorA.size();
    this.A = new DenseDoubleMatrix2D(n, 2);
    this.b = new DenseDoubleMatrix2D(1, n);

    for (int i = 0; i < n; i++) {
        A.set(i, 0, 1); // intercept column
        A.set(i, 1, vectorA.get(i));
        b.set(0, i, vectorB.get(i));
    }

    fit();
}
280
281
282
283
284
285
286
287
/**
 * Weighted simple linear regression of one vector on another: fits {@code vectorB ~ 1 + vectorA}.
 *
 * @param vectorA the predictor values
 * @param vectorB the response values; must have the same length as vectorA
 * @param weights one weight per observation; must have the same length as vectorA
 * @throws IllegalArgumentException if any argument is null or the lengths differ
 */
public LeastSquaresFit(DoubleMatrix1D vectorA, DoubleMatrix1D vectorB, final DoubleMatrix1D weights) {
    // Validate explicitly: assertions are normally disabled, and a length mismatch would otherwise either
    // leave part of the response/weights as zeros or fail later with an unrelated index error.
    if (vectorA == null || vectorB == null || weights == null) {
        throw new IllegalArgumentException("Vectors and weights must not be null");
    }
    if (vectorA.size() != vectorB.size()) {
        throw new IllegalArgumentException(
                "Vectors must be the same length: " + vectorA.size() + " != " + vectorB.size());
    }
    if (vectorA.size() != weights.size()) {
        throw new IllegalArgumentException(
                "Weights must be the same length as the data: " + weights.size() + " != " + vectorA.size());
    }

    int n = vectorA.size();
    this.A = new DenseDoubleMatrix2D(n, 2);
    this.b = new DenseDoubleMatrix2D(1, n);
    this.weights = new DenseDoubleMatrix2D(1, n);

    for (int i = 0; i < n; i++) {
        A.set(i, 0, 1); // intercept column
        A.set(i, 1, vectorA.get(i));
        b.set(0, i, vectorB.get(i));
        this.weights.set(0, i, weights.get(i));
    }

    fit();
}
307
308
309
310
311
312
313
/**
 * Unweighted fit from raw matrices. No DesignMatrix, term names or row names are associated with the fit.
 *
 * @param A the design matrix
 * @param b the data, one model per row
 */
public LeastSquaresFit(DoubleMatrix2D A, DoubleMatrix2D b) {
    this.b = b;
    this.A = A;
    this.fit();
}
319
320
321
322
323
324
325
326
327 public LeastSquaresFit(DoubleMatrix2D A, DoubleMatrix2D b, final DoubleMatrix2D weights) {
328 assert A != null;
329 assert b != null;
330 assert A.rows() == b.columns();
331 assert weights == null || b.columns() == weights.columns();
332 assert weights == null || b.rows() == weights.rows();
333
334 this.A = A;
335 this.b = b;
336 this.weights = weights;
337
338 fit();
339
340 }
341
342
343
344
345
/**
 * Unweighted fit where the design is built from a table of sample information. The model is treated as having
 * an intercept.
 *
 * @param sampleInfo sample information used to construct the DesignMatrix
 * @param data       the data, one model per row (used directly, not copied)
 */
public LeastSquaresFit(ObjectMatrix<String, String, Object> sampleInfo, DenseDoubleMatrix2D data) {
    this.hasIntercept = true;
    this.b = data;

    // NOTE(review): the 'true' flag presumably asks DesignMatrix to add an intercept - confirm.
    this.designMatrix = new DesignMatrix(sampleInfo, true);
    this.terms = designMatrix.getTerms();
    this.assign = designMatrix.getAssign();
    this.A = designMatrix.getDoubleMatrix();

    fit();
}
358
359
360
361
362
363
/**
 * Unweighted fit where the design is built from a table of sample information, optionally including the
 * interaction between the two factors.
 *
 * @param sampleInfo   sample information used to construct the DesignMatrix
 * @param data         the data, one model per row (used directly, not copied)
 * @param interactions if true, an interaction term is added to the design before fitting; see
 *                     {@code addInteraction()} for the restrictions
 */
public LeastSquaresFit(ObjectMatrix<String, String, Object> sampleInfo, DenseDoubleMatrix2D data,
        boolean interactions) {
    this.b = data;

    this.designMatrix = new DesignMatrix(sampleInfo, true);
    if (interactions) {
        addInteraction();
    }

    // Read the design only after any interaction has been added, so it is reflected in what we fit.
    this.terms = designMatrix.getTerms();
    this.assign = designMatrix.getAssign();
    this.A = designMatrix.getDoubleMatrix();

    fit();
}
379
380
381
382
383
384
385
/**
 * Unweighted fit of named data, where the design is built from a table of sample information.
 *
 * @param design sample information used to construct the DesignMatrix
 * @param b      the data, one model per row; row names are kept as keys for the summaries
 */
public LeastSquaresFit(ObjectMatrix<String, String, Object> design, DoubleMatrix<String, String> b) {
    this.designMatrix = new DesignMatrix(design, true);
    this.terms = designMatrix.getTerms();
    this.assign = designMatrix.getAssign();
    this.A = designMatrix.getDoubleMatrix();

    this.rowNames = b.getRowNames();
    this.b = new DenseDoubleMatrix2D(b.asArray());

    fit();
}
397
398
399
400
401
402
403
/**
 * Unweighted fit of named data, where the design is built from a table of sample information, optionally
 * including the interaction between the two factors.
 *
 * @param design       sample information used to construct the DesignMatrix
 * @param data         the data, one model per row; row names are kept as keys for the summaries
 * @param interactions if true, an interaction term is added to the design before fitting; see
 *                     {@code addInteraction()} for the restrictions
 */
public LeastSquaresFit(ObjectMatrix<String, String, Object> design, DoubleMatrix<String, String> data,
        boolean interactions) {
    this.designMatrix = new DesignMatrix(design, true);

    if (interactions) {
        addInteraction();
    }

    // Read the design only after any interaction has been added, so it is reflected in what we fit.
    this.A = designMatrix.getDoubleMatrix();
    this.assign = designMatrix.getAssign();
    this.terms = designMatrix.getTerms();

    this.b = new DenseDoubleMatrix2D(data.asArray());
    // Keep the row names, as the other named-data constructors do; without them summaries have no key
    // and summarizeByKeys() cannot work.
    this.rowNames = data.getRowNames();
    fit();
}
419
420
421
422
423
424
425
426 public DoubleMatrix2D getCoefficients() {
427 return coefficients;
428 }
429
430 public double getDfPrior() {
431 return dfPrior;
432 }
433
434 public DoubleMatrix2D getFitted() {
435 return fitted;
436 }
437
438 public int getResidualDof() {
439 return residualDof;
440 }
441
442 public List<Integer> getResidualDofs() {
443 return residualDofs;
444 }
445
446 public DoubleMatrix2D getResiduals() {
447 return residuals;
448 }
449
450
451
452
453 public DoubleMatrix2D getStudentizedResiduals() {
454 int dof = this.residualDof - 1;
455
456 assert dof > 0;
457
458 if (this.hasMissing) {
459 throw new UnsupportedOperationException("Studentizing not supported with missing values");
460 }
461
462 DoubleMatrix2D result = this.residuals.like();
463
464
465
466
467 DoubleMatrix2D q = this.getQR(0).getQ();
468
469 DoubleMatrix1D hatdiag = new DenseDoubleMatrix1D(residuals.columns());
470 for (int j = 0; j < residuals.columns(); j++) {
471 double hj = q.viewRow(j).aggregate(Functions.plus, Functions.square);
472 if (1.0 - hj < Constants.TINY) {
473 hj = 1.0;
474 }
475 hatdiag.set(j, hj);
476 }
477
478
479
480
481 for (int i = 0; i < residuals.rows(); i++) {
482
483
484
485
486 DoubleMatrix1D residualRow = residuals.viewRow(i);
487
488 if (this.weights != null) {
489
490 DoubleMatrix1D w = weights.viewRow(i).copy().assign(Functions.sqrt);
491 residualRow = residualRow.copy().assign(w, Functions.mult);
492 }
493
494 double sum = residualRow.aggregate(Functions.plus, Functions.square);
495
496 for (int j = 0; j < residualRow.size(); j++) {
497
498 double hj = hatdiag.get(j);
499
500
501 double sigma;
502
503 if (hj < 1.0) {
504 sigma = Math.sqrt((sum - Math.pow(residualRow.get(j), 2) / (1.0 - hj)) / dof);
505 } else {
506 sigma = Math.sqrt(sum / dof);
507 }
508
509 double res = residualRow.getQuick(j);
510 double studres = res / (sigma * Math.sqrt(1.0 - hj));
511
512 if (log.isDebugEnabled()) log.debug("sigma=" + sigma + " hj=" + hj + " stres=" + studres);
513
514 result.set(i, j, studres);
515 }
516 }
517 return result;
518 }
519
520 public DoubleMatrix1D getVarPost() {
521 return varPost;
522 }
523
524 public double getVarPrior() {
525 return varPrior;
526 }
527
528 public DoubleMatrix2D getWeights() {
529 return weights;
530 }
531
532 public boolean isHasBeenShrunken() {
533 return hasBeenShrunken;
534 }
535
536 public boolean isHasMissing() {
537 return hasMissing;
538 }
539
540
541
542
543
544 public List<LinearModelSummary> summarize() {
545 return this.summarize(false);
546 }
547
548
549
550
551
552 public List<LinearModelSummary> summarize(boolean anova) {
553 List<LinearModelSummary> lmsresults = new ArrayList<>();
554
555 List<GenericAnovaResult> anovas = null;
556 if (anova) {
557 anovas = this.anova();
558 }
559
560 StopWatch timer = new StopWatch();
561 timer.start();
562 log.info("Summarizing");
563 for (int i = 0; i < this.coefficients.columns(); i++) {
564 LinearModelSummary lms = summarize(i);
565 lms.setAnova(anovas != null ? anovas.get(i) : null);
566 lmsresults.add(lms);
567 if (timer.getTime() > 10000 && i > 0 && i % 10000 == 0) {
568 log.info("Summarized " + i);
569 }
570 }
571 log.info("Summzarized " + this.coefficients.columns() + " results");
572
573 return lmsresults;
574 }
575
576
577
578
579
580
581
582 public Map<String, LinearModelSummary> summarizeByKeys(boolean anova) {
583 List<LinearModelSummary> summaries = this.summarize(anova);
584 Map<String, LinearModelSummary> result = new LinkedHashMap<>();
585 for (LinearModelSummary lms : summaries) {
586 if (StringUtils.isBlank(lms.getKey())) {
587
588
589
590 throw new IllegalStateException("Key must not be blank");
591 }
592
593 if (result.containsKey(lms.getKey())) {
594 throw new IllegalStateException("Duplicate key " + lms.getKey());
595 }
596 result.put(lms.getKey(), lms);
597 }
598 return result;
599 }
600
601
602
603
604
605
606
607
608
609
610 protected List<GenericAnovaResult> anova() {
611
612 DoubleMatrix1D ones = new DenseDoubleMatrix1D(residuals.columns());
613 ones.assign(1.0);
614
615
616
617
618 DoubleMatrix1D residualSumsOfSquares;
619
620 if (this.weights == null) {
621 residualSumsOfSquares = MatrixUtil.multWithMissing(residuals.copy().assign(Functions.square),
622 ones);
623 } else {
624 residualSumsOfSquares = MatrixUtil.multWithMissing(
625 residuals.copy().assign(this.weights.copy().assign(Functions.sqrt), Functions.mult).assign(Functions.square),
626 ones);
627 }
628
629 DoubleMatrix2D effects = null;
630 if (this.hasMissing || this.weights != null) {
631 effects = new DenseDoubleMatrix2D(this.b.rows(), this.A.columns());
632 effects.assign(Double.NaN);
633 for (int i = 0; i < this.b.rows(); i++) {
634 QRDecomposition qrd = this.getQR(i);
635 if (qrd == null) {
636
637 for (int j = 0; j < effects.columns(); j++) {
638 effects.set(i, j, Double.NaN);
639 }
640 continue;
641 }
642
643
644
645
646 DoubleMatrix1D brow = b.viewRow(i);
647 DoubleMatrix1D browWithoutMissing = MatrixUtil.removeMissingOrInfinite(brow);
648
649 DoubleMatrix1D tqty;
650 if (weights != null) {
651 DoubleMatrix1D w = MatrixUtil.removeMissingOrInfinite(brow, this.weights.viewRow(i).copy().assign(Functions.sqrt));
652 assert w.size() == browWithoutMissing.size();
653 DoubleMatrix1D bw = browWithoutMissing.copy().assign(w, Functions.mult);
654 tqty = qrd.effects(bw);
655 } else {
656 tqty = qrd.effects(browWithoutMissing);
657 }
658
659
660 for (int j = 0; j < qrd.getRank(); j++) {
661 effects.set(i, j, tqty.get(j));
662 }
663 }
664
665 } else {
666 assert this.qr != null;
667 effects = qr.effects(this.b.viewDice().copy()).viewDice();
668 }
669
670
671 effects.assign(Functions.square);
672
673
674
675
676 Set<Integer> facs = new TreeSet<>();
677 facs.addAll(assign);
678
679 DoubleMatrix2D ssq = new DenseDoubleMatrix2D(effects.rows(), facs.size() + 1);
680 DoubleMatrix2D dof = new DenseDoubleMatrix2D(effects.rows(), facs.size() + 1);
681 dof.assign(0.0);
682 ssq.assign(0.0);
683 List<Integer> assignToUse = assign;
684
685 for (int i = 0; i < ssq.rows(); i++) {
686
687 ssq.set(i, facs.size(), residualSumsOfSquares.get(i));
688 int rdof;
689 if (this.residualDofs.isEmpty()) {
690 rdof = this.residualDof;
691 } else {
692 rdof = this.residualDofs.get(i);
693 }
694
695
696
697 dof.set(i, facs.size(), rdof);
698
699 if (!assigns.isEmpty()) {
700
701 assignToUse = assigns.get(i);
702 }
703
704
705 DoubleMatrix1D effectsForRow = effects.viewRow(i);
706
707 if (assignToUse.size() != effectsForRow.size()) {
708
709
710
711 log.debug("Check me: effects has missing values");
712 }
713
714 for (int j = 0; j < assignToUse.size(); j++) {
715
716 double valueToAdd = effectsForRow.get(j);
717 int col = assignToUse.get(j);
718 if (col > 0 && !this.hasIntercept) {
719 col = col - 1;
720 }
721
722
723
724
725
726
727 if (!Double.isNaN(valueToAdd) && valueToAdd > Constants.SMALL) {
728 ssq.set(i, col, ssq.get(i, col) + valueToAdd);
729 dof.set(i, col, dof.get(i, col) + 1);
730 }
731 }
732 }
733
734 DoubleMatrix1D denominator;
735 if (this.hasBeenShrunken) {
736 denominator = this.varPost.copy();
737 } else {
738 if (this.residualDofs.isEmpty()) {
739
740 denominator = residualSumsOfSquares.copy().assign(Functions.div(residualDof));
741 } else {
742 denominator = new DenseDoubleMatrix1D(residualSumsOfSquares.size());
743 for (int i = 0; i < residualSumsOfSquares.size(); i++) {
744 denominator.set(i, residualSumsOfSquares.get(i) / residualDofs.get(i));
745 }
746 }
747 }
748
749
750 DoubleMatrix2D fStats = ssq.copy().assign(dof, Functions.div);
751 DoubleMatrix2D pvalues = fStats.like();
752 computeStats(dof, fStats, denominator, pvalues);
753
754 return summarizeAnova(ssq, dof, fStats, pvalues);
755 }
756
757
758
759
760
761
762
763
764 protected void ebayesUpdate(double d, double v, DoubleMatrix1D vp) {
765 this.dfPrior = d;
766 this.varPrior = v;
767 this.varPost = vp;
768 this.hasBeenShrunken = true;
769 }
770
771
772
773
774
775
776
777
778
779
780
781
782 protected LinearModelSummary summarize(int i) {
783
784 String key = null;
785 if (this.rowNames != null) {
786 key = this.rowNames.get(i);
787 if (key == null) log.warn("Key null at " + i);
788 }
789
790 QRDecomposition qrd = null;
791 qrd = this.getQR(i);
792
793 if (qrd == null) {
794 log.debug("QR was null for item " + i);
795 return new LinearModelSummary(key);
796 }
797
798 int rdf;
799 if (this.residualDofs.isEmpty()) {
800 rdf = this.residualDof;
801 } else {
802 rdf = this.residualDofs.get(i);
803 }
804 assert !Double.isNaN(rdf);
805
806 if (rdf == 0) {
807 return new LinearModelSummary(key);
808 }
809
810 DoubleMatrix1D resid = MatrixUtil.removeMissingOrInfinite(this.residuals.viewRow(i));
811 DoubleMatrix1D f = MatrixUtil.removeMissingOrInfinite(fitted.viewRow(i));
812
813 DoubleMatrix1D rweights = null;
814 DoubleMatrix1D sqrtweights = null;
815 if (this.weights != null) {
816 rweights = MatrixUtil.removeMissingOrInfinite(fitted.viewRow(i), this.weights.viewRow(i).copy());
817 sqrtweights = rweights.copy().assign(Functions.sqrt);
818 } else {
819 rweights = new DenseDoubleMatrix1D(f.size()).assign(1.0);
820 sqrtweights = rweights.copy();
821 }
822
823 DoubleMatrix1D allCoef = coefficients.viewColumn(i);
824 DoubleMatrix1D estCoef = MatrixUtil.removeMissingOrInfinite(allCoef);
825
826 if (estCoef.size() == 0) {
827 log.warn("No coefficients estimated for row " + i + this.diagnosis(qrd));
828 log.info("Data for this row:\n" + this.b.viewRow(i));
829 return new LinearModelSummary(key);
830 }
831
832 int rank = qrd.getRank();
833 int n = qrd.getQ().rows();
834 assert rdf == n - rank : "Rank was not correct, expected " + rdf + " but got Q rows=" + n + ", #Coef=" + rank
835 + diagnosis(qrd);
836
837
838
839
840
841
842
843
844
845
846
847
848
849 double mss;
850
851 if (weights != null) {
852
853 if (hasIntercept) {
854
855 double m = f.copy().assign(Functions.div(rweights.zSum())).assign(rweights, Functions.mult).zSum();
856
857 mss = f.copy().assign(Functions.minus(m)).assign(Functions.square).assign(rweights, Functions.mult).zSum();
858 } else {
859 mss = f.copy().assign(Functions.square).assign(rweights, Functions.mult).zSum();
860 }
861
862 assert resid.size() == rweights.size();
863 } else {
864 if (hasIntercept) {
865 mss = f.copy().assign(Functions.minus(Descriptive.mean(new DoubleArrayList(f.toArray()))))
866 .assign(Functions.square).zSum();
867 } else {
868 mss = f.copy().assign(Functions.square).zSum();
869 }
870 }
871
872 double rss = resid.copy().assign(Functions.square).assign(rweights, Functions.mult).zSum();
873 if (weights != null) resid = resid.copy().assign(sqrtweights, Functions.mult);
874
875 double resvar = rss / rdf;
876
877
878 DoubleMatrix<String, String> summaryTable = DoubleMatrixFactory.dense(allCoef.size(), 4);
879 summaryTable.assign(Double.NaN);
880 summaryTable
881 .setColumnNames(Arrays.asList(new String[]{"Estimate", "Std. Error", "t value", "Pr(>|t|)"}));
882
883
884
885 DoubleMatrix2D XtXi = qrd.chol2inv();
886
887
888 DoubleMatrix1D sdUnscaled = MatrixUtil.diagonal(XtXi).assign(Functions.sqrt);
889
890
891
892
893 this.stdevUnscaled.put(i, sdUnscaled);
894
895 DoubleMatrix1D sdScaled = MatrixUtil
896 .removeMissingOrInfinite(MatrixUtil.diagonal(XtXi).assign(Functions.mult(resvar))
897 .assign(Functions.sqrt));
898
899
900
901 DoubleMatrix1D effects = qrd.effects(MatrixUtil.removeMissingOrInfinite(this.b.viewRow(i).copy()).assign(sqrtweights, Functions.mult));
902
903
904
905
906
907
908
909
910
911 double sigma = Math.sqrt(
912 effects.copy().viewPart(rank, effects.size() - rank).aggregate(Functions.plus, Functions.square) / (effects.size() - rank));
913
914
915
916
917
918 DoubleMatrix1D tstats;
919 TDistribution tdist;
920 if (this.hasBeenShrunken) {
921
922
923
924
925 tstats = estCoef.copy().assign(sdUnscaled, Functions.div).assign(
926 Functions.div(Math.sqrt(this.varPost.get(i))));
927
928
929
930
931
932
933
934
935
936 double dfTotal = rdf + this.dfPrior;
937
938 assert !Double.isNaN(dfTotal);
939 tdist = new TDistribution(dfTotal);
940 } else {
941
942
943
944
945
946 tstats = estCoef.copy().assign(sdScaled, Functions.div);
947 tdist = new TDistribution(rdf);
948 }
949
950 int j = 0;
951 for (int ti = 0; ti < allCoef.size(); ti++) {
952 double c = allCoef.get(ti);
953 assert this.designMatrix != null;
954 List<String> colNames = this.designMatrix.getMatrix().getColNames();
955
956 String dmcolname;
957 if (colNames == null) {
958 dmcolname = "Column_" + ti;
959 } else {
960 dmcolname = colNames.get(ti);
961 }
962
963 summaryTable.addRowName(dmcolname);
964 if (Double.isNaN(c)) {
965 continue;
966 }
967
968 summaryTable.set(ti, 0, estCoef.get(j));
969 summaryTable.set(ti, 1, sdUnscaled.get(j));
970 summaryTable.set(ti, 2, tstats.get(j));
971
972 double pval = 2.0 * (1.0 - tdist.cumulativeProbability(Math.abs(tstats.get(j))));
973 summaryTable.set(ti, 3, pval);
974
975 j++;
976
977 }
978
979 double rsquared = 0.0;
980 double adjRsquared = 0.0;
981 double fstatistic = 0.0;
982 int numdf = 0;
983 int dendf = 0;
984
985 if (terms.size() > 1 || !hasIntercept) {
986 int dfint = hasIntercept ? 1 : 0;
987 rsquared = mss / (mss + rss);
988 adjRsquared = 1 - (1 - rsquared) * ((n - dfint) / (double) rdf);
989
990 fstatistic = mss / (rank - dfint) / resvar;
991
992
993 numdf = rank - dfint;
994 dendf = rdf;
995
996 } else {
997
998 rsquared = 0.0;
999 adjRsquared = 0.0;
1000 }
1001
1002
1003
1004 LinearModelSummaryrModelSummary.html#LinearModelSummary">LinearModelSummary lms = new LinearModelSummary(key, ArrayUtils.toObject(allCoef.toArray()),
1005 ArrayUtils.toObject(resid
1006 .toArray()),
1007 terms,
1008 summaryTable, ArrayUtils.toObject(effects.toArray()),
1009 ArrayUtils.toObject(sdUnscaled.toArray()), rsquared,
1010 adjRsquared,
1011 fstatistic,
1012 numdf, dendf, null, sigma, this.hasBeenShrunken);
1013 lms.setPriorDof(this.dfPrior);
1014
1015 return lms;
1016 }
1017
1018
1019
1020
1021 private void addInteraction() {
1022 if (designMatrix.getTerms().size() == 1) {
1023 throw new IllegalArgumentException("Need at least two factors for interactions");
1024 }
1025 if (designMatrix.getTerms().size() != 2) {
1026 throw new UnsupportedOperationException("Interactions not supported for more than two factors");
1027 }
1028 this.designMatrix.addInteraction(designMatrix.getTerms().get(0), designMatrix.getTerms().get(1));
1029 }
1030
1031
1032
1033
1034
1035
1036
1037
1038 private void addQR(Integer row, BitVector valuesPresent, QRDecomposition newQR) {
1039
1040
1041
1042
1043
1044
1045 if (this.weights != null) {
1046 this.qrsForWeighted.put(row, newQR);
1047 return;
1048 }
1049
1050 assert row != null;
1051
1052 if (valuesPresent == null) {
1053 valuesPresentMap.put(row, null);
1054 }
1055
1056 QRDecomposition cachedQr = qrs.get(valuesPresent);
1057 if (cachedQr == null) {
1058 qrs.put(valuesPresent, newQR);
1059 }
1060 valuesPresentMap.put(row, valuesPresent);
1061 }
1062
1063
1064
1065
1066 private void checkForMissingValues() {
1067 for (int i = 0; i < b.rows(); i++) {
1068 for (int j = 0; j < b.columns(); j++) {
1069 double v = b.get(i, j);
1070 if (Double.isNaN(v) || Double.isInfinite(v)) {
1071 this.hasMissing = true;
1072 log.info("Data has missing values (at row=" + (i + 1) + " column=" + (j + 1));
1073 break;
1074 }
1075 }
1076 if (this.hasMissing) break;
1077 }
1078 }
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093 private DoubleMatrix2D cleanDesign(final DoubleMatrix2D design, int ypsize, List<Integer> droppedColumns) {
1094
1095
1096
1097
1098 for (int j = 0; j < design.columns(); j++) {
1099 if (j == 0 && this.hasIntercept) continue;
1100 double lastValue = Double.NaN;
1101 boolean constant = true;
1102 for (int i = 0; i < design.rows(); i++) {
1103 double thisvalue = design.get(i, j);
1104 if (i > 0 && thisvalue != lastValue) {
1105 constant = false;
1106 break;
1107 }
1108 lastValue = thisvalue;
1109 }
1110 if (constant) {
1111 log.debug("Dropping constant column " + j);
1112 droppedColumns.add(j);
1113 continue;
1114 }
1115
1116 DoubleMatrix1D col = design.viewColumn(j);
1117
1118 for (int p = 0; p < j; p++) {
1119 boolean redundant = true;
1120 DoubleMatrix1D otherCol = design.viewColumn(p);
1121 for (int v = 0; v < col.size(); v++) {
1122 if (col.get(v) != otherCol.get(v)) {
1123 redundant = false;
1124 break;
1125 }
1126 }
1127 if (redundant) {
1128 log.debug("Dropping redundant column " + j);
1129 droppedColumns.add(j);
1130 break;
1131 }
1132 }
1133
1134 }
1135
1136 DoubleMatrix2D returnValue = MatrixUtil.dropColumns(design, droppedColumns);
1137
1138 return returnValue;
1139 }
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149 private void computeStats(DoubleMatrix2D dof, DoubleMatrix2D fStats, DoubleMatrix1D denominator,
1150 DoubleMatrix2D pvalues) {
1151 pvalues.assign(Double.NaN);
1152 int timesWarned = 0;
1153 for (int i = 0; i < fStats.rows(); i++) {
1154
1155 int rdof;
1156 if (this.residualDofs.isEmpty()) {
1157 rdof = residualDof;
1158 } else {
1159 rdof = this.residualDofs.get(i);
1160 }
1161
1162 for (int j = 0; j < fStats.columns(); j++) {
1163
1164 double ndof = dof.get(i, j);
1165
1166 if (ndof <= 0 || rdof <= 0) {
1167 pvalues.set(i, j, Double.NaN);
1168 fStats.set(i, j, Double.NaN);
1169 continue;
1170 }
1171
1172 if (j == fStats.columns() - 1) {
1173
1174 pvalues.set(i, j, Double.NaN);
1175 fStats.set(i, j, Double.NaN);
1176 continue;
1177 }
1178
1179
1180
1181
1182 if (fStats.get(i, j) < Constants.SMALLISH && denominator.get(i) < Constants.SMALLISH) {
1183 pvalues.set(i, j, Double.NaN);
1184 fStats.set(i, j, Double.NaN);
1185 continue;
1186 }
1187
1188 fStats.set(i, j, fStats.get(i, j) / denominator.get(i));
1189 try {
1190 FDistribution pf = new FDistribution(ndof, rdof + this.dfPrior);
1191 pvalues.set(i, j, 1.0 - pf.cumulativeProbability(fStats.get(i, j)));
1192 } catch (NotStrictlyPositiveException e) {
1193 if (timesWarned < 10) {
1194 log.warn("Pvalue could not be computed for F=" + fStats.get(i, j) + "; denominator was="
1195 + denominator.get(i) + "; Error: " + e.getMessage()
1196 + " (limited warnings of this type will be given)");
1197 timesWarned++;
1198 }
1199 pvalues.set(i, j, Double.NaN);
1200 }
1201
1202 }
1203 }
1204 }
1205
1206
1207
1208
1209
1210 private String diagnosis(QRDecomposition qrd) {
1211 StringBuilder buf = new StringBuilder();
1212 buf.append("\n--------\nLM State\n--------\n");
1213 buf.append("hasMissing=" + this.hasMissing + "\n");
1214 buf.append("hasIntercept=" + this.hasIntercept + "\n");
1215 buf.append("Design: " + this.designMatrix + "\n");
1216 if (this.b.rows() < 5) {
1217 buf.append("Data matrix: " + this.b + "\n");
1218 } else {
1219 buf.append("Data (first few rows): " + this.b.viewSelection(new int[]{0, 1, 2, 3, 4}, null) + "\n");
1220
1221 }
1222 buf.append("Current QR:" + qrd + "\n");
1223 return buf.toString();
1224 }
1225
1226
1227
1228
1229 private void fit() {
1230 if (this.weights == null) {
1231 lsf();
1232 return;
1233 }
1234 wlsf();
1235 }
1236
1237
1238
1239
1240
1241 private QRDecomposition getQR(BitVector valuesPresent) {
1242 assert this.weights == null;
1243 return qrs.get(valuesPresent);
1244 }
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256 private QRDecomposition getQR(Integer row) {
1257 if (!this.hasMissing && this.weights == null) {
1258 return this.qr;
1259 }
1260
1261 if (this.weights != null)
1262 return qrsForWeighted.get(row);
1263
1264 assert this.hasMissing;
1265 BitVector key = valuesPresentMap.get(row);
1266 if (key == null) return null;
1267 return qrs.get(key);
1268
1269 }
1270
1271
1272
1273
1274 private void lsf() {
1275
1276 assert this.weights == null;
1277
1278 checkForMissingValues();
1279 Algebra solver = new Algebra();
1280
1281 if (this.hasMissing) {
1282 double[][] rawResult = new double[b.rows()][];
1283 for (int i = 0; i < b.rows(); i++) {
1284
1285 DoubleMatrix1D row = b.viewRow(i);
1286 if (row.size() < 3) {
1287 rawResult[i] = new double[A.columns()];
1288 continue;
1289 }
1290 DoubleMatrix1D withoutMissing = lsfWmissing(i, row, A);
1291 if (withoutMissing == null) {
1292 rawResult[i] = new double[A.columns()];
1293 } else {
1294 rawResult[i] = withoutMissing.toArray();
1295 }
1296 }
1297 this.coefficients = new DenseDoubleMatrix2D(rawResult).viewDice();
1298
1299 } else {
1300
1301 this.qr = new QRDecomposition(A);
1302 this.coefficients = qr.solve(solver.transpose(b));
1303 this.residualDof = b.columns() - qr.getRank();
1304 if (residualDof <= 0) {
1305 throw new IllegalArgumentException(
1306 "No residual degrees of freedom to fit the model" + diagnosis(qr));
1307 }
1308
1309 }
1310 assert this.assign.isEmpty() || this.assign.size() == this.coefficients.rows() : assign.size()
1311 + " != # coefficients " + this.coefficients.rows();
1312
1313 assert this.coefficients.rows() == A.columns();
1314
1315
1316 this.fitted = solver.transpose(MatrixUtil.multWithMissing(A, coefficients));
1317
1318 if (this.hasMissing) {
1319 MatrixUtil.maskMissing(b, fitted);
1320 }
1321
1322 this.residuals = b.copy().assign(fitted, Functions.minus);
1323 }
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336 private DoubleMatrix1D lsfWmissing(Integer row, DoubleMatrix1D y, DoubleMatrix2D des) {
1337 Algebra solver = new Algebra();
1338
1339
1340 List<Double> ywithoutMissingList = new ArrayList<>(y.size());
1341 int size = y.size();
1342 boolean hasAssign = !this.assign.isEmpty();
1343 int countNonMissing = 0;
1344 for (int i = 0; i < size; i++) {
1345 double v = y.getQuick(i);
1346 if (!Double.isNaN(v) && !Double.isInfinite(v)) {
1347 countNonMissing++;
1348 }
1349 }
1350
1351 if (countNonMissing < 3) {
1352
1353
1354
1355 DoubleMatrix1D re = new DenseDoubleMatrix1D(des.columns());
1356 re.assign(Double.NaN);
1357 log.debug("Not enough non-missing values");
1358 this.addQR(row, null, null);
1359 this.residualDofs.add(countNonMissing - des.columns());
1360 if (hasAssign) this.assigns.add(new ArrayList<Integer>());
1361 return re;
1362 }
1363
1364 double[][] rawDesignWithoutMissing = new double[countNonMissing][];
1365 int index = 0;
1366 boolean missing = false;
1367
1368 BitVector bv = new BitVector(size);
1369 for (int i = 0; i < size; i++) {
1370 double yi = y.getQuick(i);
1371 if (Double.isNaN(yi) || Double.isInfinite(yi)) {
1372 missing = true;
1373 continue;
1374 }
1375 ywithoutMissingList.add(yi);
1376 bv.set(i);
1377 rawDesignWithoutMissing[index++] = des.viewRow(i).toArray();
1378 }
1379 double[] yWithoutMissing = ArrayUtils.toPrimitive(ywithoutMissingList.toArray(new Double[]{}));
1380 DenseDoubleMatrix2D yWithoutMissingAsMatrix = new DenseDoubleMatrix2D(new double[][]{yWithoutMissing});
1381
1382 DoubleMatrix2D designWithoutMissing = new DenseDoubleMatrix2D(rawDesignWithoutMissing);
1383
1384 boolean fail = false;
1385 List<Integer> droppedColumns = new ArrayList<>();
1386 designWithoutMissing = this.cleanDesign(designWithoutMissing, yWithoutMissingAsMatrix.size(), droppedColumns);
1387
1388 if (designWithoutMissing.columns() == 0 || designWithoutMissing.columns() > designWithoutMissing.rows()) {
1389 fail = true;
1390 }
1391
1392 if (fail) {
1393 DoubleMatrix1D re = new DenseDoubleMatrix1D(des.columns());
1394 re.assign(Double.NaN);
1395 this.addQR(row, null, null);
1396 this.residualDofs.add(countNonMissing - des.columns());
1397 if (hasAssign) this.assigns.add(new ArrayList<Integer>());
1398 return re;
1399 }
1400
1401 QRDecomposition rqr = null;
1402 if (this.weights != null) {
1403 rqr = new QRDecomposition(designWithoutMissing);
1404 addQR(row, null, rqr);
1405 } else if (missing) {
1406 rqr = this.getQR(bv);
1407 if (rqr == null) {
1408 rqr = new QRDecomposition(designWithoutMissing);
1409 addQR(row, bv, rqr);
1410 }
1411 } else {
1412
1413
1414 if (this.qr == null) {
1415 rqr = new QRDecomposition(des);
1416 } else {
1417
1418 rqr = this.qr;
1419 }
1420 }
1421
1422 this.addQR(row, bv, rqr);
1423
1424 int pivots = rqr.getRank();
1425
1426 int rdof = yWithoutMissingAsMatrix.size() - pivots;
1427 this.residualDofs.add(rdof);
1428
1429 DoubleMatrix2D coefs = rqr.solve(solver.transpose(yWithoutMissingAsMatrix));
1430
1431
1432
1433
1434 if (designWithoutMissing.columns() < des.columns()) {
1435 DoubleMatrix1D col = coefs.viewColumn(0);
1436 DoubleMatrix1D result = new DenseDoubleMatrix1D(des.columns());
1437 result.assign(Double.NaN);
1438 int k = 0;
1439 List<Integer> assignForRow = new ArrayList<>();
1440 for (int i = 0; i < des.columns(); i++) {
1441 if (droppedColumns.contains(i)) {
1442
1443 continue;
1444 }
1445
1446 if (hasAssign) assignForRow.add(this.assign.get(i));
1447 assert k < col.size();
1448 result.set(i, col.get(k));
1449 k++;
1450 }
1451 if (hasAssign) assigns.add(assignForRow);
1452 return result;
1453 }
1454 if (hasAssign) assigns.add(this.assign);
1455 return coefs.viewColumn(0);
1456
1457 }
1458
1459
1460
1461
1462
1463
1464
1465
1466 private List<GenericAnovaResult> summarizeAnova(DoubleMatrix2D ssq, DoubleMatrix2D dof, DoubleMatrix2D fStats,
1467 DoubleMatrix2D pvalues) {
1468
1469 assert ssq != null;
1470 assert dof != null;
1471 assert fStats != null;
1472 assert pvalues != null;
1473
1474 List<GenericAnovaResult> results = new ArrayList<>();
1475 for (int i = 0; i < fStats.rows(); i++) {
1476 Collection<AnovaEffect> efs = new ArrayList<>();
1477
1478
1479
1480
1481 for (int j = 0; j < fStats.columns() - 1; j++) {
1482 String effectName = terms.get(j);
1483 assert effectName != null;
1484 AnovaEffectvaEffect.html#AnovaEffect">AnovaEffect ae = new AnovaEffect(effectName, pvalues.get(i, j), fStats.get(i, j), dof.get(
1485 i, j), ssq.get(i, j), effectName.contains(":"));
1486 efs.add(ae);
1487 }
1488
1489
1490
1491
1492 int residCol = fStats.columns() - 1;
1493 AnovaEffectvaEffect.html#AnovaEffect">AnovaEffect ae = new AnovaEffect("Residual", null, null, dof.get(i, residCol) + this.dfPrior, ssq.get(i,
1494 residCol), false);
1495 efs.add(ae);
1496
1497 GenericAnovaResultricAnovaResult.html#GenericAnovaResult">GenericAnovaResult ao = new GenericAnovaResult(efs);
1498 if (this.rowNames != null) ao.setKey(this.rowNames.get(i));
1499 results.add(ao);
1500 }
1501 return results;
1502 }
1503
1504
1505
1506
1507 private void wlsf() {
1508
1509 assert this.weights != null;
1510
1511 checkForMissingValues();
1512 Algebra solver = new Algebra();
1513
1514
1515
1516
1517 List<DoubleMatrix2D> AwList = new ArrayList<>(b.rows());
1518 List<DoubleMatrix1D> bList = new ArrayList<>(b.rows());
1519
1520
1521
1522
1523
1524
1525
1526
1527 for (int i = 0; i < b.rows(); i++) {
1528 DoubleMatrix1D wts = this.weights.viewRow(i).copy().assign(Functions.sqrt);
1529 DoubleMatrix1D bw = b.viewRow(i).copy().assign(wts, Functions.mult);
1530 DoubleMatrix2D Aw = A.copy();
1531 for (int j = 0; j < Aw.columns(); j++) {
1532 Aw.viewColumn(j).assign(wts, Functions.mult);
1533 }
1534 AwList.add(Aw);
1535 bList.add(bw);
1536 }
1537
1538 double[][] rawResult = new double[b.rows()][];
1539
1540 if (this.hasMissing) {
1541
1542
1543
1544 for (int i = 0; i < b.rows(); i++) {
1545 DoubleMatrix1D bw = bList.get(i);
1546 DoubleMatrix2D Aw = AwList.get(i);
1547 DoubleMatrix1D withoutMissing = lsfWmissing(i, bw, Aw);
1548 if (withoutMissing == null) {
1549 rawResult[i] = new double[A.columns()];
1550 } else {
1551 rawResult[i] = withoutMissing.toArray();
1552 }
1553 }
1554
1555 } else {
1556
1557
1558
1559 for (int i = 0; i < b.rows(); i++) {
1560 DoubleMatrix1D bw = bList.get(i);
1561 DoubleMatrix2D Aw = AwList.get(i);
1562 DoubleMatrix2D bw2D = new DenseDoubleMatrix2D(1, bw.size());
1563 bw2D.viewRow(0).assign(bw);
1564 QRDecompositionosition.html#QRDecomposition">QRDecomposition wqr = new QRDecomposition(Aw);
1565
1566
1567 this.addQR(i, null, wqr);
1568
1569 rawResult[i] = wqr.solve(solver.transpose(bw2D)).viewColumn(0).toArray();
1570 this.residualDof = bw.size() - wqr.getRank();
1571 assert this.residualDof >= 0;
1572 if (residualDof == 0) {
1573 throw new IllegalArgumentException("No residual degrees of freedom to fit the model"
1574 + diagnosis(wqr));
1575 }
1576 }
1577 }
1578
1579 this.coefficients = solver.transpose(new DenseDoubleMatrix2D(rawResult));
1580
1581 assert this.assign.isEmpty() || this.assign.size() == this.coefficients.rows() : assign.size()
1582 + " != # coefficients " + this.coefficients.rows();
1583 assert this.coefficients.rows() == A.columns();
1584
1585 this.fitted = solver.transpose(MatrixUtil.multWithMissing(A, coefficients));
1586
1587 if (this.hasMissing) {
1588 MatrixUtil.maskMissing(b, fitted);
1589 }
1590
1591 this.residuals = b.copy().assign(fitted, Functions.minus);
1592 }
1593
1594 }