1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  package ubic.basecode.math.linearmodels;
16  
17  import java.io.File;
18  import java.io.IOException;
19  import java.io.OutputStream;
20  import java.io.PrintStream;
21  import java.util.ArrayList;
22  import java.util.Arrays;
23  import java.util.Collection;
24  import java.util.HashMap;
25  import java.util.LinkedHashMap;
26  import java.util.List;
27  import java.util.Map;
28  import java.util.Set;
29  import java.util.TreeMap;
30  import java.util.TreeSet;
31  
32  import org.apache.commons.lang3.ArrayUtils;
33  import org.apache.commons.lang3.StringUtils;
34  import org.apache.commons.lang3.time.StopWatch;
35  import org.apache.commons.math3.distribution.FDistribution;
36  import org.apache.commons.math3.distribution.TDistribution;
37  import org.apache.commons.math3.exception.NotStrictlyPositiveException;
38  import org.slf4j.Logger;
39  import org.slf4j.LoggerFactory;
40  
41  import cern.colt.bitvector.BitVector;
42  import cern.colt.list.DoubleArrayList;
43  import cern.colt.matrix.DoubleMatrix1D;
44  import cern.colt.matrix.DoubleMatrix2D;
45  import cern.colt.matrix.impl.DenseDoubleMatrix1D;
46  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
47  import cern.colt.matrix.linalg.Algebra;
48  import cern.jet.math.Functions;
49  import cern.jet.stat.Descriptive;
50  import ubic.basecode.dataStructure.matrix.DoubleMatrix;
51  import ubic.basecode.dataStructure.matrix.DoubleMatrixFactory;
52  import ubic.basecode.dataStructure.matrix.MatrixUtil;
53  import ubic.basecode.dataStructure.matrix.ObjectMatrix;
54  import ubic.basecode.math.Constants;
55  import ubic.basecode.math.linalg.QRDecomposition;
56  import ubic.basecode.util.r.type.AnovaEffect;
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  public class LeastSquaresFit {
68  
    private static Logger log = LoggerFactory.getLogger(LeastSquaresFit.class);

    /**
     * True once {@link #ebayesUpdate} has supplied moderated (posterior) variances; the summaries and the ANOVA then
     * use {@link #varPost} instead of the ordinary residual variance.
     */
    boolean hasBeenShrunken = false;

    /**
     * The design matrix used for fitting: one row per sample (i.e. per column of {@link #b}).
     */
    private DoubleMatrix2D A;

    /**
     * For each column of {@link #A}, the index of the model term it belongs to (from
     * {@link DesignMatrix#getAssign()}); used by {@link #anova()} to pool effects into per-term sums of squares.
     */
    private List<Integer> assign = new ArrayList<>();

    /**
     * Per-row variants of {@link #assign}; when non-empty, {@link #anova()} uses entry {@code i} for data row
     * {@code i}. NOTE(review): presumably populated when rows are fitted with row-specific designs - confirm in fit().
     */
    private List<List<Integer>> assigns = new ArrayList<>();

    /**
     * The data being fitted: each row is fitted separately, with one column per sample.
     */
    private DoubleMatrix2D b;

    /**
     * Fitted coefficients, one column per data row; NaN entries mark coefficients that were not estimated.
     */
    private DoubleMatrix2D coefficients = null;

    /**
     * The design used for the fit; null when the fit was constructed directly from matrices or vectors.
     */
    private DesignMatrix designMatrix;

    /**
     * Prior degrees of freedom set by {@link #ebayesUpdate}; 0 until then.
     */
    private double dfPrior = 0;

    /**
     * Fitted values, one row per data row (presumably the same shape as {@link #b}).
     */
    private DoubleMatrix2D fitted;

    /**
     * Whether the model has an intercept; when true, column 0 of the design is treated as the intercept.
     */
    private boolean hasIntercept = true;

    /**
     * True if {@link #b} contains any NaN or infinite value (set by checkForMissingValues).
     */
    private boolean hasMissing = false;

    /**
     * QR decomposition of the design shared by all rows; used when there are neither missing values nor weights.
     */
    private QRDecomposition qr = null;

    /**
     * Cache of QR decompositions keyed by the pattern of present (non-missing) values, so that rows with the same
     * pattern share one decomposition (see addQR).
     */
    private Map<BitVector, QRDecomposition> qrs = new HashMap<>();

    /**
     * QR decompositions of the weighted design, one per data row: weights may differ between rows, so nothing is
     * shared (see addQR).
     */
    private Map<Integer, QRDecomposition> qrsForWeighted = new HashMap<>();

    // Residual degrees of freedom shared by all rows; used whenever residualDofs is empty. -1 until set.
    private int residualDof = -1;

    /**
     * Per-row residual degrees of freedom; when non-empty these take precedence over {@link #residualDof}.
     */
    private List<Integer> residualDofs = new ArrayList<>();

    /**
     * Residuals of the fit, one row per data row (may contain missing values).
     */
    private DoubleMatrix2D residuals = null;

    /**
     * Names of the data rows, used as keys of the summaries; may be null.
     */
    private List<String> rowNames;

    /**
     * Unscaled coefficient standard deviations (square roots of the diagonal of (X'X)^-1), keyed by data row; filled
     * in by summarize(int).
     */
    private Map<Integer, DoubleMatrix1D> stdevUnscaled = new TreeMap<>();

    /**
     * Names of the model terms, from the design matrix.
     */
    private List<String> terms;

    /**
     * For each data row, the pattern of present values that identifies its QR decomposition in {@link #qrs}.
     */
    private Map<Integer, BitVector> valuesPresentMap = new HashMap<>();

    /**
     * Posterior (moderated) variances, one per data row; null until {@link #ebayesUpdate} is called.
     */
    private DoubleMatrix1D varPost = null;

    /**
     * Prior variance from the empirical Bayes step; null until {@link #ebayesUpdate} is called.
     */
    private Double varPrior = null;

    /**
     * Observation weights laid out like {@link #b}; null for an unweighted fit.
     */
    private DoubleMatrix2D weights = null;
191 
192     
193 
194 
195 
196 
197 
198     public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix<String, String> data) {
199         this.designMatrix = designMatrix;
200         this.A = designMatrix.getDoubleMatrix();
201         this.assign = designMatrix.getAssign();
202         this.terms = designMatrix.getTerms();
203 
204         this.rowNames = data.getRowNames();
205         this.b = new DenseDoubleMatrix2D(data.asArray());
206         boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
207         this.hasIntercept = designMatrix.hasIntercept();
208         assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
209         fit();
210     }
211 
212     
213 
214 
215 
216 
217 
218 
219     public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix<String, String> data,
220                            final DoubleMatrix2D weights) {
221         this.designMatrix = designMatrix;
222         DoubleMatrix2D X = designMatrix.getDoubleMatrix();
223         this.assign = designMatrix.getAssign();
224         this.terms = designMatrix.getTerms();
225         this.A = X;
226         this.rowNames = data.getRowNames();
227         this.b = new DenseDoubleMatrix2D(data.asArray());
228         boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
229         this.hasIntercept = designMatrix.hasIntercept();
230         assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
231         this.weights = weights;
232         fit();
233     }
234 
235     
236 
237 
238 
239 
240 
241 
242     public LeastSquaresFit(DesignMatrix designMatrix, DoubleMatrix2D b, final DoubleMatrix2D weights) {
243 
244         this.designMatrix = designMatrix;
245         DoubleMatrix2D X = designMatrix.getDoubleMatrix();
246         this.assign = designMatrix.getAssign();
247         this.terms = designMatrix.getTerms();
248         this.A = X;
249         this.b = b;
250         boolean hasInterceptTerm = this.terms.contains(LinearModelSummary.INTERCEPT_COEFFICIENT_NAME);
251         this.hasIntercept = designMatrix.hasIntercept();
252         assert hasInterceptTerm == this.hasIntercept : diagnosis(null);
253 
254         this.weights = weights;
255 
256         fit();
257 
258     }
259 
260     
261 
262 
263 
264 
265 
266     public LeastSquaresFit(DoubleMatrix1D vectorA, DoubleMatrix1D vectorB) {
267         assert vectorA.size() == vectorB.size();
268 
269         this.A = new DenseDoubleMatrix2D(vectorA.size(), 2);
270         this.b = new DenseDoubleMatrix2D(1, vectorB.size());
271 
272         for (int i = 0; i < vectorA.size(); i++) {
273             A.set(i, 0, 1);
274             A.set(i, 1, vectorA.get(i));
275             b.set(0, i, vectorB.get(i));
276         }
277 
278         fit();
279     }
280 
281     
282 
283 
284 
285 
286 
287 
288     public LeastSquaresFit(DoubleMatrix1D vectorA, DoubleMatrix1D vectorB, final DoubleMatrix1D weights) {
289 
290         assert vectorA.size() == vectorB.size();
291         assert vectorA.size() == weights.size();
292 
293         this.A = new DenseDoubleMatrix2D(vectorA.size(), 2);
294         this.b = new DenseDoubleMatrix2D(1, vectorB.size());
295         this.weights = new DenseDoubleMatrix2D(1, weights.size());
296 
297         for (int i = 0; i < vectorA.size(); i++) {
298             
299             A.set(i, 0, 1);
300             A.set(i, 1, vectorA.get(i));
301             b.set(0, i, vectorB.get(i));
302             this.weights.set(0, i, weights.get(i));
303         }
304 
305         fit();
306     }
307 
308     
309 
310 
311 
312 
313 
314     public LeastSquaresFit(DoubleMatrix2D A, DoubleMatrix2D b) {
315         this.A = A;
316         this.b = b;
317         fit();
318     }
319 
320     
321 
322 
323 
324 
325 
326 
327     public LeastSquaresFit(DoubleMatrix2D A, DoubleMatrix2D b, final DoubleMatrix2D weights) {
328         assert A != null;
329         assert b != null;
330         assert A.rows() == b.columns();
331         assert weights == null || b.columns() == weights.columns();
332         assert weights == null || b.rows() == weights.rows();
333 
334         this.A = A;
335         this.b = b;
336         this.weights = weights;
337 
338         fit();
339 
340     }
341 
342     
343 
344 
345 
346     public LeastSquaresFit(ObjectMatrix<String, String, Object> sampleInfo, DenseDoubleMatrix2D data) {
347 
348         this.designMatrix = new DesignMatrix(sampleInfo, true);
349 
350         this.hasIntercept = true;
351         this.A = designMatrix.getDoubleMatrix();
352         this.assign = designMatrix.getAssign();
353         this.terms = designMatrix.getTerms();
354 
355         this.b = data;
356         fit();
357     }
358 
359     
360 
361 
362 
363 
364     public LeastSquaresFit(ObjectMatrix<String, String, Object> sampleInfo, DenseDoubleMatrix2D data,
365                            boolean interactions) {
366         this.designMatrix = new DesignMatrix(sampleInfo, true);
367 
368         if (interactions) {
369             addInteraction();
370         }
371 
372         this.A = designMatrix.getDoubleMatrix();
373         this.assign = designMatrix.getAssign();
374         this.terms = designMatrix.getTerms();
375 
376         this.b = data;
377         fit();
378     }
379 
380     
381 
382 
383 
384 
385 
386     public LeastSquaresFit(ObjectMatrix<String, String, Object> design, DoubleMatrix<String, String> b) {
387         this.designMatrix = new DesignMatrix(design, true);
388 
389         this.A = designMatrix.getDoubleMatrix();
390         this.assign = designMatrix.getAssign();
391         this.terms = designMatrix.getTerms();
392 
393         this.b = new DenseDoubleMatrix2D(b.asArray());
394         this.rowNames = b.getRowNames();
395         fit();
396     }
397 
398     
399 
400 
401 
402 
403 
404     public LeastSquaresFit(ObjectMatrix<String, String, Object> design, DoubleMatrix<String, String> data,
405                            boolean interactions) {
406         this.designMatrix = new DesignMatrix(design, true);
407 
408         if (interactions) {
409             addInteraction();
410         }
411 
412         DoubleMatrix2D X = designMatrix.getDoubleMatrix();
413         this.assign = designMatrix.getAssign();
414         this.terms = designMatrix.getTerms();
415         this.A = X;
416         this.b = new DenseDoubleMatrix2D(data.asArray());
417         fit();
418     }
419 
420     
421 
422 
423 
424 
425 
426     public DoubleMatrix2D getCoefficients() {
427         return coefficients;
428     }
429 
430     public double getDfPrior() {
431         return dfPrior;
432     }
433 
434     public DoubleMatrix2D getFitted() {
435         return fitted;
436     }
437 
438     public int getResidualDof() {
439         return residualDof;
440     }
441 
442     public List<Integer> getResidualDofs() {
443         return residualDofs;
444     }
445 
446     public DoubleMatrix2D getResiduals() {
447         return residuals;
448     }
449 
450     
451 
452 
453     public DoubleMatrix2D getStudentizedResiduals() {
454         int dof = this.residualDof - 1; 
455 
456         assert dof > 0;
457 
458         if (this.hasMissing) {
459             throw new UnsupportedOperationException("Studentizing not supported with missing values");
460         }
461 
462         DoubleMatrix2D result = this.residuals.like();
463 
464         
465 
466 
467         DoubleMatrix2D q = this.getQR(0).getQ();
468 
469         DoubleMatrix1D hatdiag = new DenseDoubleMatrix1D(residuals.columns());
470         for (int j = 0; j < residuals.columns(); j++) {
471             double hj = q.viewRow(j).aggregate(Functions.plus, Functions.square);
472             if (1.0 - hj < Constants.TINY) {
473                 hj = 1.0;
474             }
475             hatdiag.set(j, hj);
476         }
477 
478         
479 
480 
481         for (int i = 0; i < residuals.rows(); i++) {
482 
483             
484             
485 
486             DoubleMatrix1D residualRow = residuals.viewRow(i);
487 
488             if (this.weights != null) {
489                 
490                 DoubleMatrix1D w = weights.viewRow(i).copy().assign(Functions.sqrt);
491                 residualRow = residualRow.copy().assign(w, Functions.mult);
492             }
493 
494             double sum = residualRow.aggregate(Functions.plus, Functions.square);
495 
496             for (int j = 0; j < residualRow.size(); j++) {
497 
498                 double hj = hatdiag.get(j);
499 
500                 
501                 double sigma;
502 
503                 if (hj < 1.0) {
504                     sigma = Math.sqrt((sum - Math.pow(residualRow.get(j), 2) / (1.0 - hj)) / dof);
505                 } else {
506                     sigma = Math.sqrt(sum / dof);
507                 }
508 
509                 double res = residualRow.getQuick(j);
510                 double studres = res / (sigma * Math.sqrt(1.0 - hj));
511 
512                 if (log.isDebugEnabled()) log.debug("sigma=" + sigma + " hj=" + hj + " stres=" + studres);
513 
514                 result.set(i, j, studres);
515             }
516         }
517         return result;
518     }
519 
520     public DoubleMatrix1D getVarPost() {
521         return varPost;
522     }
523 
524     public double getVarPrior() {
525         return varPrior;
526     }
527 
528     public DoubleMatrix2D getWeights() {
529         return weights;
530     }
531 
532     public boolean isHasBeenShrunken() {
533         return hasBeenShrunken;
534     }
535 
536     public boolean isHasMissing() {
537         return hasMissing;
538     }
539 
540     
541 
542 
543 
544     public List<LinearModelSummary> summarize() {
545         return this.summarize(false);
546     }
547 
548     
549 
550 
551 
552     public List<LinearModelSummary> summarize(boolean anova) {
553         List<LinearModelSummary> lmsresults = new ArrayList<>();
554 
555         List<GenericAnovaResult> anovas = null;
556         if (anova) {
557             anovas = this.anova();
558         }
559 
560         StopWatch timer = new StopWatch();
561         timer.start();
562         log.info("Summarizing");
563         for (int i = 0; i < this.coefficients.columns(); i++) {
564             LinearModelSummary lms = summarize(i);
565             lms.setAnova(anovas != null ? anovas.get(i) : null);
566             lmsresults.add(lms);
567             if (timer.getTime() > 10000 && i > 0 && i % 10000 == 0) {
568                 log.info("Summarized " + i);
569             }
570         }
571         log.info("Summzarized " + this.coefficients.columns() + " results");
572 
573         return lmsresults;
574     }
575 
576     
577 
578 
579 
580 
581 
582     public Map<String, LinearModelSummary> summarizeByKeys(boolean anova) {
583         List<LinearModelSummary> summaries = this.summarize(anova);
584         Map<String, LinearModelSummary> result = new LinkedHashMap<>();
585         for (LinearModelSummary lms : summaries) {
586             if (StringUtils.isBlank(lms.getKey())) {
587                 
588 
589 
590                 throw new IllegalStateException("Key must not be blank");
591             }
592 
593             if (result.containsKey(lms.getKey())) {
594                 throw new IllegalStateException("Duplicate key " + lms.getKey());
595             }
596             result.put(lms.getKey(), lms);
597         }
598         return result;
599     }
600 
601     
602 
603 
604 
605 
606 
607 
608 
609 
    /**
     * Sequential analysis of variance for every data row, along the lines of R's {@code anova.lm}.
     * <p>
     * The squared effects (Q'y) are pooled by model term (via {@link #assign}, or the row-specific {@link #assigns})
     * into sums of squares. Each non-negligible effect adds one degree of freedom to its term. The last column of the
     * intermediate tables holds the residual sum of squares and the residual degrees of freedom.
     * <p>
     * The mean squares are then divided, in computeStats, by the residual variance of the row, or by the posterior
     * variance if the fit has been shrunken.
     *
     * @return the ANOVA results assembled by summarizeAnova; {@link #summarize(boolean)} indexes them by data row
     */
    protected List<GenericAnovaResult> anova() {

        DoubleMatrix1D ones = new DenseDoubleMatrix1D(residuals.columns());
        ones.assign(1.0);

        // Residual sum of squares per row: the squared (for weighted fits: sqrt-weighted, then squared) residuals
        // multiplied by a vector of ones, i.e. summed across each row. multWithMissing presumably tolerates missing
        // values - confirm in MatrixUtil.
        DoubleMatrix1D residualSumsOfSquares;

        if (this.weights == null) {
            residualSumsOfSquares = MatrixUtil.multWithMissing(residuals.copy().assign(Functions.square),
                    ones);
        } else {
            residualSumsOfSquares = MatrixUtil.multWithMissing(
                    residuals.copy().assign(this.weights.copy().assign(Functions.sqrt), Functions.mult).assign(Functions.square),
                    ones);
        }

        DoubleMatrix2D effects = null;
        if (this.hasMissing || this.weights != null) {
            // Rows may have different decompositions, so compute the effects row by row.
            effects = new DenseDoubleMatrix2D(this.b.rows(), this.A.columns());
            effects.assign(Double.NaN);
            for (int i = 0; i < this.b.rows(); i++) {
                QRDecomposition qrd = this.getQR(i);
                if (qrd == null) {
                    // no fit for this row: its effects stay NaN
                    for (int j = 0; j < effects.columns(); j++) {
                        effects.set(i, j, Double.NaN);
                    }
                    continue;
                }

                // Effects of the non-missing data of this row; for weighted fits the data are first multiplied by
                // the square roots of the matching weights.
                DoubleMatrix1D brow = b.viewRow(i);
                DoubleMatrix1D browWithoutMissing = MatrixUtil.removeMissingOrInfinite(brow);

                DoubleMatrix1D tqty;
                if (weights != null) {
                    DoubleMatrix1D w = MatrixUtil.removeMissingOrInfinite(brow, this.weights.viewRow(i).copy().assign(Functions.sqrt));
                    assert w.size() == browWithoutMissing.size();
                    DoubleMatrix1D bw = browWithoutMissing.copy().assign(w, Functions.mult);
                    tqty = qrd.effects(bw);
                } else {
                    tqty = qrd.effects(browWithoutMissing);
                }

                // only the first 'rank' effects belong to model columns; the rest stay NaN
                for (int j = 0; j < qrd.getRank(); j++) {
                    effects.set(i, j, tqty.get(j));
                }
            }

        } else {
            // No missing values and no weights: one decomposition serves all rows at once. b is transposed so that
            // each column is a data row, and the result is transposed back to one row per data row.
            assert this.qr != null;
            effects = qr.effects(this.b.viewDice().copy()).viewDice();
        }

        // squared effects: each is the sum of squares contributed by one model column
        effects.assign(Functions.square); 

        // Distinct term indices. ssq and dof get one column per term, plus a final column for the residuals.
        Set<Integer> facs = new TreeSet<>();
        facs.addAll(assign);

        DoubleMatrix2D ssq = new DenseDoubleMatrix2D(effects.rows(), facs.size() + 1);
        DoubleMatrix2D dof = new DenseDoubleMatrix2D(effects.rows(), facs.size() + 1);
        dof.assign(0.0);
        ssq.assign(0.0);
        List<Integer> assignToUse = assign; // replaced per row when row-specific assigns exist

        for (int i = 0; i < ssq.rows(); i++) {

            ssq.set(i, facs.size(), residualSumsOfSquares.get(i));
            int rdof;
            if (this.residualDofs.isEmpty()) {
                rdof = this.residualDof;
            } else {
                rdof = this.residualDofs.get(i);
            }

            // residual degrees of freedom go in the last column
            dof.set(i, facs.size(), rdof);

            if (!assigns.isEmpty()) {
                // row-specific mapping of design columns to terms
                assignToUse = assigns.get(i);
            }

            DoubleMatrix1D effectsForRow = effects.viewRow(i);

            if (assignToUse.size() != effectsForRow.size()) {
                // The effects can have more entries than there are design columns (e.g. one per observation when a
                // single decomposition was used); only the first assignToUse.size() entries are used below.
                log.debug("Check me: effects has missing values");
            }

            for (int j = 0; j < assignToUse.size(); j++) {

                double valueToAdd = effectsForRow.get(j);
                int col = assignToUse.get(j);
                // Without an intercept, shift term indices down by one. Presumably the terms are then numbered from
                // 1 - confirm against DesignMatrix#getAssign.
                if (col > 0 && !this.hasIntercept) {
                    col = col - 1;
                }

                // Effects that are NaN (not estimated) or negligible (e.g. from aliased columns) are skipped and do
                // not count towards the term's degrees of freedom.
                if (!Double.isNaN(valueToAdd) && valueToAdd > Constants.SMALL) {
                    ssq.set(i, col, ssq.get(i, col) + valueToAdd);
                    dof.set(i, col, dof.get(i, col) + 1);
                }
            }
        }

        // Residual variance per row: the posterior variance after shrinkage, otherwise RSS / residual dof.
        DoubleMatrix1D denominator;
        if (this.hasBeenShrunken) {
            denominator = this.varPost.copy();
        } else {
            if (this.residualDofs.isEmpty()) {
                // one residual dof shared by all rows
                denominator = residualSumsOfSquares.copy().assign(Functions.div(residualDof));
            } else {
                denominator = new DenseDoubleMatrix1D(residualSumsOfSquares.size());
                for (int i = 0; i < residualSumsOfSquares.size(); i++) {
                    denominator.set(i, residualSumsOfSquares.get(i) / residualDofs.get(i));
                }
            }
        }

        // mean squares (SS / dof); computeStats divides them by the denominator to obtain F statistics
        DoubleMatrix2D fStats = ssq.copy().assign(dof, Functions.div); 
        DoubleMatrix2D pvalues = fStats.like();
        computeStats(dof, fStats, denominator, pvalues);

        return summarizeAnova(ssq, dof, fStats, pvalues);
    }
756 
757     
758 
759 
760 
761 
762 
763 
764     protected void ebayesUpdate(double d, double v, DoubleMatrix1D vp) {
765         this.dfPrior = d;
766         this.varPrior = v; 
767         this.varPost = vp; 
768         this.hasBeenShrunken = true;
769     }
770 
771     
772 
773 
774 
775 
776 
777 
778 
779 
780 
781 
782     protected LinearModelSummary summarize(int i) {
783 
784         String key = null;
785         if (this.rowNames != null) {
786             key = this.rowNames.get(i);
787             if (key == null) log.warn("Key null at " + i);
788         }
789 
790         QRDecomposition qrd = null;
791         qrd = this.getQR(i);
792 
793         if (qrd == null) {
794             log.debug("QR was null for item " + i);
795             return new LinearModelSummary(key);
796         }
797 
798         int rdf;
799         if (this.residualDofs.isEmpty()) {
800             rdf = this.residualDof; 
801         } else {
802             rdf = this.residualDofs.get(i); 
803         }
804         assert !Double.isNaN(rdf);
805 
806         if (rdf == 0) {
807             return new LinearModelSummary(key);
808         }
809 
810         DoubleMatrix1D resid = MatrixUtil.removeMissingOrInfinite(this.residuals.viewRow(i));
811         DoubleMatrix1D f = MatrixUtil.removeMissingOrInfinite(fitted.viewRow(i));
812 
813         DoubleMatrix1D rweights = null;
814         DoubleMatrix1D sqrtweights = null;
815         if (this.weights != null) {
816             rweights = MatrixUtil.removeMissingOrInfinite(fitted.viewRow(i), this.weights.viewRow(i).copy());
817             sqrtweights = rweights.copy().assign(Functions.sqrt);
818         } else {
819             rweights = new DenseDoubleMatrix1D(f.size()).assign(1.0);
820             sqrtweights = rweights.copy();
821         }
822 
823         DoubleMatrix1D allCoef = coefficients.viewColumn(i); 
824         DoubleMatrix1D estCoef = MatrixUtil.removeMissingOrInfinite(allCoef); 
825 
826         if (estCoef.size() == 0) {
827             log.warn("No coefficients estimated for row " + i + this.diagnosis(qrd));
828             log.info("Data for this row:\n" + this.b.viewRow(i));
829             return new LinearModelSummary(key);
830         }
831 
832         int rank = qrd.getRank();
833         int n = qrd.getQ().rows();
834         assert rdf == n - rank : "Rank was not correct, expected " + rdf + " but got Q rows=" + n + ", #Coef=" + rank
835                 + diagnosis(qrd);
836 
837         
838         
839         
840         
841         
842         
843         
844         
845         
846         
847         
848         
849         double mss;
850 
851         if (weights != null) {
852 
853             if (hasIntercept) {
854                 
855                 double m = f.copy().assign(Functions.div(rweights.zSum())).assign(rweights, Functions.mult).zSum();
856 
857                 mss = f.copy().assign(Functions.minus(m)).assign(Functions.square).assign(rweights, Functions.mult).zSum();
858             } else {
859                 mss = f.copy().assign(Functions.square).assign(rweights, Functions.mult).zSum();
860             }
861 
862             assert resid.size() == rweights.size();
863         } else {
864             if (hasIntercept) {
865                 mss = f.copy().assign(Functions.minus(Descriptive.mean(new DoubleArrayList(f.toArray()))))
866                         .assign(Functions.square).zSum();
867             } else {
868                 mss = f.copy().assign(Functions.square).zSum();
869             }
870         }
871 
872         double rss = resid.copy().assign(Functions.square).assign(rweights, Functions.mult).zSum();
873         if (weights != null) resid = resid.copy().assign(sqrtweights, Functions.mult);
874 
875         double resvar = rss / rdf; 
876 
877         
878         DoubleMatrix<String, String> summaryTable = DoubleMatrixFactory.dense(allCoef.size(), 4);
879         summaryTable.assign(Double.NaN);
880         summaryTable
881                 .setColumnNames(Arrays.asList(new String[]{"Estimate", "Std. Error", "t value", "Pr(>|t|)"}));
882 
883         
884 
885         DoubleMatrix2D XtXi = qrd.chol2inv();
886 
887         
888         DoubleMatrix1D sdUnscaled = MatrixUtil.diagonal(XtXi).assign(Functions.sqrt);
889 
890         
891         
892 
893         this.stdevUnscaled.put(i, sdUnscaled);
894 
895         DoubleMatrix1D sdScaled = MatrixUtil
896                 .removeMissingOrInfinite(MatrixUtil.diagonal(XtXi).assign(Functions.mult(resvar))
897                         .assign(Functions.sqrt)); 
898 
899         
900 
901         DoubleMatrix1D effects = qrd.effects(MatrixUtil.removeMissingOrInfinite(this.b.viewRow(i).copy()).assign(sqrtweights, Functions.mult));
902 
903         
904         
905         
906         
907         
908         
909 
910         
911         double sigma = Math.sqrt(
912                 effects.copy().viewPart(rank, effects.size() - rank).aggregate(Functions.plus, Functions.square) / (effects.size() - rank));
913 
914         
915 
916 
917 
918         DoubleMatrix1D tstats;
919         TDistribution tdist;
920         if (this.hasBeenShrunken) {
921             
922 
923 
924 
925             tstats = estCoef.copy().assign(sdUnscaled, Functions.div).assign(
926                     Functions.div(Math.sqrt(this.varPost.get(i))));
927 
928             
929 
930 
931 
932 
933 
934 
935 
936             double dfTotal = rdf + this.dfPrior;
937 
938             assert !Double.isNaN(dfTotal);
939             tdist = new TDistribution(dfTotal);
940         } else {
941             
942 
943 
944 
945 
946             tstats = estCoef.copy().assign(sdScaled, Functions.div);
947             tdist = new TDistribution(rdf);
948         }
949 
950         int j = 0;
951         for (int ti = 0; ti < allCoef.size(); ti++) {
952             double c = allCoef.get(ti);
953             assert this.designMatrix != null;
954             List<String> colNames = this.designMatrix.getMatrix().getColNames();
955 
956             String dmcolname;
957             if (colNames == null) {
958                 dmcolname = "Column_" + ti;
959             } else {
960                 dmcolname = colNames.get(ti);
961             }
962 
963             summaryTable.addRowName(dmcolname);
964             if (Double.isNaN(c)) {
965                 continue;
966             }
967 
968             summaryTable.set(ti, 0, estCoef.get(j));
969             summaryTable.set(ti, 1, sdUnscaled.get(j));
970             summaryTable.set(ti, 2, tstats.get(j));
971 
972             double pval = 2.0 * (1.0 - tdist.cumulativeProbability(Math.abs(tstats.get(j))));
973             summaryTable.set(ti, 3, pval);
974 
975             j++;
976 
977         }
978 
979         double rsquared = 0.0;
980         double adjRsquared = 0.0;
981         double fstatistic = 0.0;
982         int numdf = 0;
983         int dendf = 0;
984 
985         if (terms.size() > 1 || !hasIntercept) {
986             int dfint = hasIntercept ? 1 : 0;
987             rsquared = mss / (mss + rss);
988             adjRsquared = 1 - (1 - rsquared) * ((n - dfint) / (double) rdf);
989 
990             fstatistic = mss / (rank - dfint) / resvar;
991 
992             
993             numdf = rank - dfint;
994             dendf = rdf;
995 
996         } else {
997             
998             rsquared = 0.0;
999             adjRsquared = 0.0;
1000         }
1001 
1002         
1003         
1004         LinearModelSummary lms = new LinearModelSummary(key, ArrayUtils.toObject(allCoef.toArray()),
1005                 ArrayUtils.toObject(resid
1006                         .toArray()),
1007                 terms,
1008                 summaryTable, ArrayUtils.toObject(effects.toArray()),
1009                 ArrayUtils.toObject(sdUnscaled.toArray()), rsquared,
1010                 adjRsquared,
1011                 fstatistic,
1012                 numdf, dendf, null, sigma, this.hasBeenShrunken);
1013         lms.setPriorDof(this.dfPrior);
1014 
1015         return lms;
1016     }
1017 
1018     
1019 
1020 
1021     private void addInteraction() {
1022         if (designMatrix.getTerms().size() == 1) {
1023             throw new IllegalArgumentException("Need at least two factors for interactions");
1024         }
1025         if (designMatrix.getTerms().size() != 2) {
1026             throw new UnsupportedOperationException("Interactions not supported for more than two factors");
1027         }
1028         this.designMatrix.addInteraction(designMatrix.getTerms().get(0), designMatrix.getTerms().get(1));
1029     }
1030 
1031     
1032 
1033 
1034 
1035 
1036 
1037 
1038     private void addQR(Integer row, BitVector valuesPresent, QRDecomposition newQR) {
1039 
1040         
1041 
1042 
1043 
1044 
1045         if (this.weights != null) {
1046             this.qrsForWeighted.put(row, newQR);
1047             return;
1048         }
1049 
1050         assert row != null;
1051 
1052         if (valuesPresent == null) {
1053             valuesPresentMap.put(row, null);
1054         }
1055 
1056         QRDecomposition cachedQr = qrs.get(valuesPresent);
1057         if (cachedQr == null) {
1058             qrs.put(valuesPresent, newQR);
1059         }
1060         valuesPresentMap.put(row, valuesPresent);
1061     }
1062 
1063     
1064 
1065 
1066     private void checkForMissingValues() {
1067         for (int i = 0; i < b.rows(); i++) {
1068             for (int j = 0; j < b.columns(); j++) {
1069                 double v = b.get(i, j);
1070                 if (Double.isNaN(v) || Double.isInfinite(v)) {
1071                     this.hasMissing = true;
1072                     log.info("Data has missing values (at row=" + (i + 1) + " column=" + (j + 1));
1073                     break;
1074                 }
1075             }
1076             if (this.hasMissing) break;
1077         }
1078     }
1079 
1080     
1081 
1082 
1083 
1084 
1085 
1086 
1087 
1088 
1089 
1090 
1091 
1092 
1093     private DoubleMatrix2D cleanDesign(final DoubleMatrix2D design, int ypsize, List<Integer> droppedColumns) {
1094 
1095         
1096 
1097 
1098         for (int j = 0; j < design.columns(); j++) {
1099             if (j == 0 && this.hasIntercept) continue;
1100             double lastValue = Double.NaN;
1101             boolean constant = true;
1102             for (int i = 0; i < design.rows(); i++) {
1103                 double thisvalue = design.get(i, j);
1104                 if (i > 0 && thisvalue != lastValue) {
1105                     constant = false;
1106                     break;
1107                 }
1108                 lastValue = thisvalue;
1109             }
1110             if (constant) {
1111                 log.debug("Dropping constant column " + j);
1112                 droppedColumns.add(j);
1113                 continue;
1114             }
1115 
1116             DoubleMatrix1D col = design.viewColumn(j);
1117 
1118             for (int p = 0; p < j; p++) {
1119                 boolean redundant = true;
1120                 DoubleMatrix1D otherCol = design.viewColumn(p);
1121                 for (int v = 0; v < col.size(); v++) {
1122                     if (col.get(v) != otherCol.get(v)) {
1123                         redundant = false;
1124                         break;
1125                     }
1126                 }
1127                 if (redundant) {
1128                     log.debug("Dropping redundant column " + j);
1129                     droppedColumns.add(j);
1130                     break;
1131                 }
1132             }
1133 
1134         }
1135 
1136         DoubleMatrix2D returnValue = MatrixUtil.dropColumns(design, droppedColumns);
1137 
1138         return returnValue;
1139     }
1140 
1141     
1142 
1143 
1144 
1145 
1146 
1147 
1148 
1149     private void computeStats(DoubleMatrix2D dof, DoubleMatrix2D fStats, DoubleMatrix1D denominator,
1150                               DoubleMatrix2D pvalues) {
1151         pvalues.assign(Double.NaN);
1152         int timesWarned = 0;
1153         for (int i = 0; i < fStats.rows(); i++) {
1154 
1155             int rdof;
1156             if (this.residualDofs.isEmpty()) {
1157                 rdof = residualDof;
1158             } else {
1159                 rdof = this.residualDofs.get(i);
1160             }
1161 
1162             for (int j = 0; j < fStats.columns(); j++) {
1163 
1164                 double ndof = dof.get(i, j);
1165 
1166                 if (ndof <= 0 || rdof <= 0) {
1167                     pvalues.set(i, j, Double.NaN);
1168                     fStats.set(i, j, Double.NaN);
1169                     continue;
1170                 }
1171 
1172                 if (j == fStats.columns() - 1) {
1173                     
1174                     pvalues.set(i, j, Double.NaN);
1175                     fStats.set(i, j, Double.NaN);
1176                     continue;
1177                 }
1178 
1179                 
1180 
1181 
1182                 if (fStats.get(i, j) < Constants.SMALLISH && denominator.get(i) < Constants.SMALLISH) {
1183                     pvalues.set(i, j, Double.NaN);
1184                     fStats.set(i, j, Double.NaN);
1185                     continue;
1186                 }
1187 
1188                 fStats.set(i, j, fStats.get(i, j) / denominator.get(i));
1189                 try {
1190                     FDistribution pf = new FDistribution(ndof, rdof + this.dfPrior);
1191                     pvalues.set(i, j, 1.0 - pf.cumulativeProbability(fStats.get(i, j)));
1192                 } catch (NotStrictlyPositiveException e) {
1193                     if (timesWarned < 10) {
1194                         log.warn("Pvalue could not be computed for F=" + fStats.get(i, j) + "; denominator was="
1195                                 + denominator.get(i) + "; Error: " + e.getMessage()
1196                                 + " (limited warnings of this type will be given)");
1197                         timesWarned++;
1198                     }
1199                     pvalues.set(i, j, Double.NaN);
1200                 }
1201 
1202             }
1203         }
1204     }
1205 
1206     
1207 
1208 
1209 
1210     private String diagnosis(QRDecomposition qrd) {
1211         StringBuilder buf = new StringBuilder();
1212         buf.append("\n--------\nLM State\n--------\n");
1213         buf.append("hasMissing=" + this.hasMissing + "\n");
1214         buf.append("hasIntercept=" + this.hasIntercept + "\n");
1215         buf.append("Design: " + this.designMatrix + "\n");
1216         if (this.b.rows() < 5) {
1217             buf.append("Data matrix: " + this.b + "\n");
1218         } else {
1219             buf.append("Data (first few rows): " + this.b.viewSelection(new int[]{0, 1, 2, 3, 4}, null) + "\n");
1220 
1221         }
1222         buf.append("Current QR:" + qrd + "\n");
1223         return buf.toString();
1224     }
1225 
1226     
1227 
1228 
1229     private void fit() {
1230         if (this.weights == null) {
1231             lsf();
1232             return;
1233         }
1234         wlsf();
1235     }
1236 
1237     
1238 
1239 
1240 
1241     private QRDecomposition getQR(BitVector valuesPresent) {
1242         assert this.weights == null;
1243         return qrs.get(valuesPresent);
1244     }
1245 
1246     
1247 
1248 
1249 
1250 
1251 
1252 
1253 
1254 
1255 
1256     private QRDecomposition getQR(Integer row) {
1257         if (!this.hasMissing && this.weights == null) {
1258             return this.qr;
1259         }
1260 
1261         if (this.weights != null)
1262             return qrsForWeighted.get(row);
1263 
1264         assert this.hasMissing;
1265         BitVector key = valuesPresentMap.get(row);
1266         if (key == null) return null;
1267         return qrs.get(key);
1268 
1269     }
1270 
1271     
1272 
1273 
1274     private void lsf() {
1275 
1276         assert this.weights == null;
1277 
1278         checkForMissingValues();
1279         Algebra solver = new Algebra();
1280 
1281         if (this.hasMissing) {
1282             double[][] rawResult = new double[b.rows()][];
1283             for (int i = 0; i < b.rows(); i++) {
1284 
1285                 DoubleMatrix1D row = b.viewRow(i);
1286                 if (row.size() < 3) { 
1287                     rawResult[i] = new double[A.columns()];
1288                     continue;
1289                 }
1290                 DoubleMatrix1D withoutMissing = lsfWmissing(i, row, A);
1291                 if (withoutMissing == null) {
1292                     rawResult[i] = new double[A.columns()];
1293                 } else {
1294                     rawResult[i] = withoutMissing.toArray();
1295                 }
1296             }
1297             this.coefficients = new DenseDoubleMatrix2D(rawResult).viewDice();
1298 
1299         } else {
1300 
1301             this.qr = new QRDecomposition(A);
1302             this.coefficients = qr.solve(solver.transpose(b));
1303             this.residualDof = b.columns() - qr.getRank();
1304             if (residualDof <= 0) {
1305                 throw new IllegalArgumentException(
1306                         "No residual degrees of freedom to fit the model" + diagnosis(qr));
1307             }
1308 
1309         }
1310         assert this.assign.isEmpty() || this.assign.size() == this.coefficients.rows() : assign.size()
1311                 + " != # coefficients " + this.coefficients.rows();
1312 
1313         assert this.coefficients.rows() == A.columns();
1314 
1315         
1316         this.fitted = solver.transpose(MatrixUtil.multWithMissing(A, coefficients));
1317 
1318         if (this.hasMissing) {
1319             MatrixUtil.maskMissing(b, fitted);
1320         }
1321 
1322         this.residuals = b.copy().assign(fitted, Functions.minus);
1323     }
1324 
1325     
1326 
1327 
1328 
1329 
1330 
1331 
1332 
1333 
1334 
1335 
1336     private DoubleMatrix1D lsfWmissing(Integer row, DoubleMatrix1D y, DoubleMatrix2D des) {
1337         Algebra solver = new Algebra();
1338         
1339 
1340         List<Double> ywithoutMissingList = new ArrayList<>(y.size());
1341         int size = y.size();
1342         boolean hasAssign = !this.assign.isEmpty();
1343         int countNonMissing = 0;
1344         for (int i = 0; i < size; i++) {
1345             double v = y.getQuick(i);
1346             if (!Double.isNaN(v) && !Double.isInfinite(v)) {
1347                 countNonMissing++;
1348             }
1349         }
1350 
1351         if (countNonMissing < 3) {
1352             
1353 
1354 
1355             DoubleMatrix1D re = new DenseDoubleMatrix1D(des.columns());
1356             re.assign(Double.NaN);
1357             log.debug("Not enough non-missing values");
1358             this.addQR(row, null, null);
1359             this.residualDofs.add(countNonMissing - des.columns());
1360             if (hasAssign) this.assigns.add(new ArrayList<Integer>());
1361             return re;
1362         }
1363 
1364         double[][] rawDesignWithoutMissing = new double[countNonMissing][];
1365         int index = 0;
1366         boolean missing = false;
1367 
1368         BitVector bv = new BitVector(size);
1369         for (int i = 0; i < size; i++) {
1370             double yi = y.getQuick(i);
1371             if (Double.isNaN(yi) || Double.isInfinite(yi)) {
1372                 missing = true;
1373                 continue;
1374             }
1375             ywithoutMissingList.add(yi);
1376             bv.set(i);
1377             rawDesignWithoutMissing[index++] = des.viewRow(i).toArray();
1378         }
1379         double[] yWithoutMissing = ArrayUtils.toPrimitive(ywithoutMissingList.toArray(new Double[]{}));
1380         DenseDoubleMatrix2D yWithoutMissingAsMatrix = new DenseDoubleMatrix2D(new double[][]{yWithoutMissing});
1381 
1382         DoubleMatrix2D designWithoutMissing = new DenseDoubleMatrix2D(rawDesignWithoutMissing);
1383 
1384         boolean fail = false;
1385         List<Integer> droppedColumns = new ArrayList<>();
1386         designWithoutMissing = this.cleanDesign(designWithoutMissing, yWithoutMissingAsMatrix.size(), droppedColumns);
1387 
1388         if (designWithoutMissing.columns() == 0 || designWithoutMissing.columns() > designWithoutMissing.rows()) {
1389             fail = true;
1390         }
1391 
1392         if (fail) {
1393             DoubleMatrix1D re = new DenseDoubleMatrix1D(des.columns());
1394             re.assign(Double.NaN);
1395             this.addQR(row, null, null);
1396             this.residualDofs.add(countNonMissing - des.columns());
1397             if (hasAssign) this.assigns.add(new ArrayList<Integer>());
1398             return re;
1399         }
1400 
1401         QRDecomposition rqr = null;
1402         if (this.weights != null) {
1403             rqr = new QRDecomposition(designWithoutMissing);
1404             addQR(row, null, rqr);
1405         } else if (missing) {
1406             rqr = this.getQR(bv);
1407             if (rqr == null) {
1408                 rqr = new QRDecomposition(designWithoutMissing);
1409                 addQR(row, bv, rqr);
1410             }
1411         } else {
1412             
1413             
1414             if (this.qr == null) {
1415                 rqr = new QRDecomposition(des);
1416             } else {
1417                 
1418                 rqr = this.qr;
1419             }
1420         }
1421 
1422         this.addQR(row, bv, rqr);
1423 
1424         int pivots = rqr.getRank();
1425 
1426         int rdof = yWithoutMissingAsMatrix.size() - pivots;
1427         this.residualDofs.add(rdof);
1428 
1429         DoubleMatrix2D coefs = rqr.solve(solver.transpose(yWithoutMissingAsMatrix));
1430 
1431         
1432 
1433 
1434         if (designWithoutMissing.columns() < des.columns()) {
1435             DoubleMatrix1D col = coefs.viewColumn(0);
1436             DoubleMatrix1D result = new DenseDoubleMatrix1D(des.columns());
1437             result.assign(Double.NaN);
1438             int k = 0;
1439             List<Integer> assignForRow = new ArrayList<>();
1440             for (int i = 0; i < des.columns(); i++) {
1441                 if (droppedColumns.contains(i)) {
1442                     
1443                     continue;
1444                 }
1445 
1446                 if (hasAssign) assignForRow.add(this.assign.get(i));
1447                 assert k < col.size();
1448                 result.set(i, col.get(k));
1449                 k++;
1450             }
1451             if (hasAssign) assigns.add(assignForRow);
1452             return result;
1453         }
1454         if (hasAssign) assigns.add(this.assign);
1455         return coefs.viewColumn(0);
1456 
1457     }
1458 
1459     
1460 
1461 
1462 
1463 
1464 
1465 
1466     private List<GenericAnovaResult> summarizeAnova(DoubleMatrix2D ssq, DoubleMatrix2D dof, DoubleMatrix2D fStats,
1467                                                     DoubleMatrix2D pvalues) {
1468 
1469         assert ssq != null;
1470         assert dof != null;
1471         assert fStats != null;
1472         assert pvalues != null;
1473 
1474         List<GenericAnovaResult> results = new ArrayList<>();
1475         for (int i = 0; i < fStats.rows(); i++) {
1476             Collection<AnovaEffect> efs = new ArrayList<>();
1477 
1478             
1479 
1480 
1481             for (int j = 0; j < fStats.columns() - 1; j++) {
1482                 String effectName = terms.get(j);
1483                 assert effectName != null;
1484                 AnovaEffect ae = new AnovaEffect(effectName, pvalues.get(i, j), fStats.get(i, j), dof.get(
1485                         i, j), ssq.get(i, j), effectName.contains(":"));
1486                 efs.add(ae);
1487             }
1488 
1489             
1490 
1491 
1492             int residCol = fStats.columns() - 1;
1493             AnovaEffect ae = new AnovaEffect("Residual", null, null, dof.get(i, residCol) + this.dfPrior, ssq.get(i,
1494                     residCol), false);
1495             efs.add(ae);
1496 
1497             GenericAnovaResult ao = new GenericAnovaResult(efs);
1498             if (this.rowNames != null) ao.setKey(this.rowNames.get(i));
1499             results.add(ao);
1500         }
1501         return results;
1502     }
1503 
1504     
1505 
1506 
1507     private void wlsf() {
1508 
1509         assert this.weights != null;
1510 
1511         checkForMissingValues();
1512         Algebra solver = new Algebra();
1513 
1514         
1515 
1516 
1517         List<DoubleMatrix2D> AwList = new ArrayList<>(b.rows());
1518         List<DoubleMatrix1D> bList = new ArrayList<>(b.rows());
1519 
1520         
1521 
1522 
1523 
1524 
1525 
1526 
1527         for (int i = 0; i < b.rows(); i++) {
1528             DoubleMatrix1D wts = this.weights.viewRow(i).copy().assign(Functions.sqrt);
1529             DoubleMatrix1D bw = b.viewRow(i).copy().assign(wts, Functions.mult);
1530             DoubleMatrix2D Aw = A.copy();
1531             for (int j = 0; j < Aw.columns(); j++) {
1532                 Aw.viewColumn(j).assign(wts, Functions.mult);
1533             }
1534             AwList.add(Aw);
1535             bList.add(bw);
1536         }
1537 
1538         double[][] rawResult = new double[b.rows()][];
1539 
1540         if (this.hasMissing) {
1541             
1542 
1543 
1544             for (int i = 0; i < b.rows(); i++) {
1545                 DoubleMatrix1D bw = bList.get(i);
1546                 DoubleMatrix2D Aw = AwList.get(i);
1547                 DoubleMatrix1D withoutMissing = lsfWmissing(i, bw, Aw);
1548                 if (withoutMissing == null) {
1549                     rawResult[i] = new double[A.columns()];
1550                 } else {
1551                     rawResult[i] = withoutMissing.toArray();
1552                 }
1553             }
1554 
1555         } else {
1556 
1557             
1558             
1559             for (int i = 0; i < b.rows(); i++) {
1560                 DoubleMatrix1D bw = bList.get(i);
1561                 DoubleMatrix2D Aw = AwList.get(i);
1562                 DoubleMatrix2D bw2D = new DenseDoubleMatrix2D(1, bw.size());
1563                 bw2D.viewRow(0).assign(bw);
1564                 QRDecomposition wqr = new QRDecomposition(Aw);
1565 
1566                 
1567                 this.addQR(i, null, wqr);
1568 
1569                 rawResult[i] = wqr.solve(solver.transpose(bw2D)).viewColumn(0).toArray();
1570                 this.residualDof = bw.size() - wqr.getRank();
1571                 assert this.residualDof >= 0;
1572                 if (residualDof == 0) {
1573                     throw new IllegalArgumentException("No residual degrees of freedom to fit the model"
1574                             + diagnosis(wqr));
1575                 }
1576             }
1577         }
1578 
1579         this.coefficients = solver.transpose(new DenseDoubleMatrix2D(rawResult));
1580 
1581         assert this.assign.isEmpty() || this.assign.size() == this.coefficients.rows() : assign.size()
1582                 + " != # coefficients " + this.coefficients.rows();
1583         assert this.coefficients.rows() == A.columns();
1584 
1585         this.fitted = solver.transpose(MatrixUtil.multWithMissing(A, coefficients));
1586 
1587         if (this.hasMissing) {
1588             MatrixUtil.maskMissing(b, fitted);
1589         }
1590 
1591         this.residuals = b.copy().assign(fitted, Functions.minus);
1592     }
1593 
1594 }