View Javadoc
1   /*
2    * The baseCode project
3    *
4    * Copyright (c) 2011 University of British Columbia
5    *
6    * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
7    * the License. You may obtain a copy of the License at
8    *
9    * http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
12   * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
13   * specific language governing permissions and limitations under the License.
14   */
15  package ubic.basecode.math.linearmodels;
16  
17  import static cern.jet.math.Functions.chain;
18  import static cern.jet.math.Functions.div;
19  import static cern.jet.math.Functions.log2;
20  import static cern.jet.math.Functions.minus;
21  import static cern.jet.math.Functions.mult;
22  import static cern.jet.math.Functions.plus;
23  import static cern.jet.math.Functions.sqrt;
24  
25  import java.util.List;
26  
27  import cern.colt.function.IntIntDoubleFunction;
28  import cern.colt.list.DoubleArrayList;
29  import cern.colt.list.IntArrayList;
30  import cern.colt.matrix.DoubleMatrix1D;
31  import cern.colt.matrix.DoubleMatrix2D;
32  import cern.colt.matrix.impl.DenseDoubleMatrix1D;
33  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
34  import cern.colt.matrix.linalg.Algebra;
35  import ubic.basecode.dataStructure.matrix.DoubleMatrix;
36  import ubic.basecode.math.DescriptiveWithMissing;
37  import ubic.basecode.math.MatrixRowStats;
38  import ubic.basecode.math.Smooth;
39  import ubic.basecode.math.linalg.QRDecomposition;
40  
41  /**
42   * Estimate mean-variance relationship and use this to compute weights for least squares fitting. R's limma.voom()
43   * Charity Law and Gordon Smyth. See Law et al.
44   * {@see http://genomebiology.biomedcentral.com/articles/10.1186/gb-2014-15-2-r29}
45   * <p>
46   * Running voom() on data matrices with NaNs is not currently supported.
47   *
48   * @author ptan
49   */
50  public class MeanVarianceEstimator {
51  
52      /**
53       * Normalized variables on log2 scale
54       */
55      private final DoubleMatrix2D E;
56  
57      /**
58       * Size of each library (column).
59       */
60      private DoubleMatrix1D librarySize;
61  
62      /**
63       * Loess fit (x, y)
64       */
65      private DoubleMatrix2D loess;
66  
67      /**
68       * Matrix that contains the mean and variance of the data. Matrix is sorted by increasing mean. Useful for plotting.
69       * mean <- fit$Amean + mean(log2(lib.size+1)) - log2(1e6) variance <- sqrt(fit$sigma)
70       */
71      private DoubleMatrix2D meanVariance;
72  
73      /**
74       * inverse variance weights
75       */
76      private DoubleMatrix2D weights = null;
77  
78      /**
79       * Preferred interface if you want control over how the design is set up. Executes voom() to calculate weights.
80       *
81       * @param designMatrix
82       * @param data         expected to be log2cpm, and already filtered
83       * @param librarySize  library size (matrix column sum)
84       */
85      public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix<String, String> data,
86                                   DoubleMatrix1D librarySize) {
87  
88          DoubleMatrix2D b = new DenseDoubleMatrix2D(data.asArray());
89          this.librarySize = librarySize;
90          this.E = b;
91          voom(designMatrix.getDoubleMatrix());
92      }
93  
94      /**
95       * Executes voom() to calculate weights.
96       *
97       * @param designMatrix
98       * @param data         a normalized count matrix
99       * @param librarySize  library size (matrix column sum)
100      */
101     public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix2D data, DoubleMatrix1D librarySize) {
102 
103         this.librarySize = librarySize;
104         this.E = data;
105         voom(designMatrix.getDoubleMatrix());
106     }
107 
108     /**
109      * Generic method for calculating mean and variance, as a data diagnostic.
110      * <p>
111      * voom() is not executed and therefore no weights
112      * are calculated.
113      *
114      * @param data data to be processed
115      */
116     public MeanVarianceEstimator(DoubleMatrix2D data) {
117         this.E = data;
118         mv();
119         this.loess = Smooth.loessFit(this.meanVariance);
120     }
121 
122     /**
123      * @return total library size
124      */
125     public DoubleMatrix1D getLibrarySize() {
126         return this.librarySize;
127     }
128 
129     /**
130      * @return the loess fit of the mean-variance relationship
131      */
132     public DoubleMatrix2D getLoess() {
133         return this.loess;
134     }
135 
136     /**
137      * @return the mean and variance of the normalized data, columns 0 and 1 respectively
138      */
139     public DoubleMatrix2D getMeanVariance() {
140         return this.meanVariance;
141     }
142 
143     /**
144      * @return data, as supplied (should be log2cpm)
145      */
146     public DoubleMatrix2D getNormalizedValue() {
147         return this.E;
148     }
149 
150     /**
151      * @return inverse variance weights if voom was applied
152      */
153     public DoubleMatrix2D getWeights() {
154         return this.weights;
155     }
156 
157 
158     /**
159      * Compute row-wise mean (x) and variance (y) on the given data. Nothing is regressed out and no loess is computed.
160      */
161     private void mv() {
162         assert this.E != null;
163 
164         // mean-variance
165         DoubleMatrix1D Amean = new DenseDoubleMatrix1D(E.rows());
166         DoubleMatrix1D variance = Amean.like();
167         for (int i = 0; i < Amean.size(); i++) {
168             DoubleArrayList row = new DoubleArrayList(E.viewRow(i).toArray());
169             double rowMean = DescriptiveWithMissing.mean(row);
170             double rowVar = DescriptiveWithMissing.variance(row);
171             Amean.set(i, rowMean);
172             variance.set(i, rowVar);
173 
174         }
175         this.meanVariance = new DenseDoubleMatrix2D(E.rows(), 2);
176         this.meanVariance.viewColumn(0).assign(Amean);
177         this.meanVariance.viewColumn(1).assign(variance);
178     }
179 
180     /**
181      * Performs the heavy duty work of calculating the weights. See Law et al.
182      * {@see http://genomebiology.biomedcentral.com/articles/10.1186/gb-2014-15-2-r29}. Assumes E is already log2cpm.
183      *
184      * @param designMatrix of factors that will be regressed out for the purpose of getting the mv relation
185      * @throws IllegalArgumentException if there are missing values.
186      */
187     private void voom(DoubleMatrix2D designMatrix) {
188         assert designMatrix != null;
189         assert this.E != null;
190         assert this.librarySize != null;
191 
192         Algebra solver = new Algebra();
193 
194         DoubleMatrix2D A = designMatrix;
195         weights = new DenseDoubleMatrix2D(E.rows(), E.columns());
196 
197         // perform a linear fit to obtain the mean-variance relationship
198         // fit3<-lm(t(yCpm) ~ as.matrix(design.matrix[,2]))
199         // or gFit <- lmFit(yCpm, design=design.matrix)
200         // as per voom, "Fit linear model to log2-counts-per-million"
201         LeastSquaresFitastSquaresFit.html#LeastSquaresFit">LeastSquaresFit lsf = new LeastSquaresFit(A, E);
202 
203 
204         // calculate fit$Amean by doing rowSums(CPM) (see limma.getEAWP())
205         DoubleMatrix1D Amean = MatrixRowStats.means(E);
206 
207         // "Fit lowess trend to sqrt-standard-deviations by log-count-size" (so we have to convert back)
208         // sx <- fit$Amean + mean(log2(lib.size+1)) - log2(1e6)
209         DoubleMatrix1D sx = Amean.copy();
210         double meanLog2LibrarySize = librarySize.copy().assign(chain(log2, plus(1))).zSum() / librarySize.size();
211         sx.assign(plus(meanLog2LibrarySize));
212         sx.assign(minus(Math.log(Math.pow(10, 6)) / Math.log(2)));
213 
214         DoubleMatrix1D sy = quarterRootVariance(lsf);
215 
216         // only accepts array in strictly increasing order (drop duplicates)
217         // so combine sx and sy and sort
218         DoubleMatrix2D voomXY = new DenseDoubleMatrix2D(sx.size(), 2);
219         voomXY.viewColumn(0).assign(sx);
220         voomXY.viewColumn(1).assign(sy);
221         DoubleMatrix2D fit = Smooth.loessFit(voomXY);
222         this.meanVariance = voomXY;
223         this.loess = fit;
224 
225         // quarterroot fitted counts
226         DoubleMatrix2D fittedValues = null;
227         QRDecompositionposition.html#QRDecomposition">QRDecomposition qr = new QRDecomposition(A);
228         DoubleMatrix2D coeff = lsf.getCoefficients();
229 
230         if (qr.getRank() < A.columns()) {
231             // j <- fit$pivot[1:fit$rank]
232             // fitted.values <- fit$coef[,j,drop=F] %*% t(fit$design[,j,drop=F]);
233             IntArrayList pivot = qr.getPivotOrder();
234 
235             IntArrayList subindices = (IntArrayList) pivot.partFromTo(0, qr.getRank() - 1);
236             int[] coeffAllCols = new int[coeff.columns()];
237             int[] desAllRows = new int[A.rows()];
238             for (int i = 0; i < coeffAllCols.length; i++) {
239                 coeffAllCols[i] = i;
240             }
241             for (int i = 0; i < desAllRows.length; i++) {
242                 desAllRows[i] = i;
243             }
244             DoubleMatrix2D coeffSlice = coeff.viewSelection(subindices.elements(), coeffAllCols);
245             DoubleMatrix2D ASlice = A.viewSelection(desAllRows, subindices.elements());
246             fittedValues = solver.mult(coeffSlice.viewDice(), ASlice.viewDice());
247         } else {
248             // fitted.values <- fit$coef %*% t(fit$design)
249             fittedValues = solver.mult(coeff.viewDice(), A.viewDice());
250         }
251 
252         // back-compute the values we want
253         // fitted.cpm <- 2^fitted.values
254         // fitted.count <- 1e-6 * t(t(fitted.cpm)*(lib.size+1))
255         // fitted.logcount <- log2(fitted.count)
256         DoubleMatrix2D fittedCpm = fittedValues.copy().forEachNonZero(new IntIntDoubleFunction() {
257             @Override
258             public double apply(int row, int column, double third) {
259                 return Math.pow(2, third);
260             }
261         });
262         DoubleMatrix2D fittedCount = fittedCpm.copy();
263         DoubleMatrix1D libSizePlusOne = librarySize.assign(plus(1));
264         for (int i = 0; i < fittedCount.rows(); i++) {
265             fittedCount.viewRow(i).assign(libSizePlusOne, mult);
266             fittedCount.viewRow(i).assign(mult(Math.pow(10, -6)));
267         }
268         DoubleMatrix2D fittedLogCount = fittedCount.copy().assign(log2);
269 
270         // to here we are *very* close to limma::voom
271 
272         // interpolate points using the loess curve
273         // f <- approxfun(l, rule=2)
274         // 2D to 1D FIXME this unrolling seems kind of unnecessary
275         double[] xInterpolate = new double[fittedLogCount.rows() * fittedLogCount.columns()];
276         int idx = 0;
277         for (int col = 0; col < fittedLogCount.columns(); col++) {
278             for (int row = 0; row < fittedLogCount.rows(); row++) {
279                 xInterpolate[idx] = fittedLogCount.get(row, col);
280                 idx++;
281             }
282         }
283         // apply trend to individual observations
284         // w <- 1/f(fitted.logcount)^4
285         assert fit != null;
286         double[] yInterpolate = Smooth.interpolate(fit.viewColumn(0).toArray(),
287                 fit.viewColumn(1).toArray(), xInterpolate);
288 
289         // 1D to 2D
290         idx = 0;
291         for (int col = 0; col < weights.columns(); col++) {
292             for (int row = 0; row < weights.rows(); row++) {
293                 weights.set(row, col, (1.0 / Math.pow(yInterpolate[idx], 4)));
294                 idx++;
295             }
296         }
297 
298     }
299 
300     protected DoubleMatrix1D quarterRootVariance(LeastSquaresFit lsf) {
301 
302         // help("MArrayLM-class")
303         // fit$sigma <- sqrt(sum(out$residuals^2)/out$df.residual)
304         // sy <- sqrt(fit$sigma)
305         DoubleMatrix2D residuals = lsf.getResiduals();
306         DoubleMatrix1D sy = new DenseDoubleMatrix1D(residuals.rows());
307         // sum squared residuals
308         for (int row = 0; row < residuals.rows(); row++) {
309             double sum = 0;
310             for (int column = 0; column < residuals.columns(); column++) {
311                 Double val = residuals.get(row, column);
312                 if (!Double.isNaN(val)) {
313                     sum += val * val;
314                 }
315             }
316             sy.set(row, sum);
317         }
318         // sigma (ssq/n)
319         // if you have missing values in the expression matrix
320         // you'll get a residual dof of 0
321         if (lsf.isHasMissing()) {
322             // calculate it per row
323             List<Integer> dofs = lsf.getResidualDofs();
324             assert dofs.size() == sy.size();
325             for (int i = 0; i < sy.size(); i++) {
326                 sy.set(i, Math.sqrt(sy.get(i) / dofs.get(i)));
327             }
328         } else {
329             int dof = lsf.getResidualDof();
330             assert dof != 0;
331             sy.assign(chain(sqrt, div(dof)));
332         }
333         // we're fitting the quarter-root variances.
334         sy.assign(sqrt);
335         return sy;
336     }
337 }