1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ubic.basecode.math.linearmodels;
16
17 import static cern.jet.math.Functions.chain;
18 import static cern.jet.math.Functions.div;
19 import static cern.jet.math.Functions.log2;
20 import static cern.jet.math.Functions.minus;
21 import static cern.jet.math.Functions.mult;
22 import static cern.jet.math.Functions.plus;
23 import static cern.jet.math.Functions.sqrt;
24
25 import java.util.List;
26
27 import cern.colt.function.IntIntDoubleFunction;
28 import cern.colt.list.DoubleArrayList;
29 import cern.colt.list.IntArrayList;
30 import cern.colt.matrix.DoubleMatrix1D;
31 import cern.colt.matrix.DoubleMatrix2D;
32 import cern.colt.matrix.impl.DenseDoubleMatrix1D;
33 import cern.colt.matrix.impl.DenseDoubleMatrix2D;
34 import cern.colt.matrix.linalg.Algebra;
35 import ubic.basecode.dataStructure.matrix.DoubleMatrix;
36 import ubic.basecode.math.DescriptiveWithMissing;
37 import ubic.basecode.math.MatrixRowStats;
38 import ubic.basecode.math.Smooth;
39 import ubic.basecode.math.linalg.QRDecomposition;
40
41
42
43
44
45
46
47
48
49
50 public class MeanVarianceEstimator {
51
52
53
54
55 private final DoubleMatrix2D E;
56
57
58
59
60 private DoubleMatrix1D librarySize;
61
62
63
64
65 private DoubleMatrix2D loess;
66
67
68
69
70
71 private DoubleMatrix2D meanVariance;
72
73
74
75
76 private DoubleMatrix2D weights = null;
77
78
79
80
81
82
83
84
85 public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix<String, String> data,
86 DoubleMatrix1D librarySize) {
87
88 DoubleMatrix2D b = new DenseDoubleMatrix2D(data.asArray());
89 this.librarySize = librarySize;
90 this.E = b;
91 voom(designMatrix.getDoubleMatrix());
92 }
93
94
95
96
97
98
99
100
101 public MeanVarianceEstimator(DesignMatrix designMatrix, DoubleMatrix2D data, DoubleMatrix1D librarySize) {
102
103 this.librarySize = librarySize;
104 this.E = data;
105 voom(designMatrix.getDoubleMatrix());
106 }
107
108
109
110
111
112
113
114
115
116 public MeanVarianceEstimator(DoubleMatrix2D data) {
117 this.E = data;
118 mv();
119 this.loess = Smooth.loessFit(this.meanVariance);
120 }
121
122
123
124
125 public DoubleMatrix1D getLibrarySize() {
126 return this.librarySize;
127 }
128
129
130
131
132 public DoubleMatrix2D getLoess() {
133 return this.loess;
134 }
135
136
137
138
139 public DoubleMatrix2D getMeanVariance() {
140 return this.meanVariance;
141 }
142
143
144
145
146 public DoubleMatrix2D getNormalizedValue() {
147 return this.E;
148 }
149
150
151
152
153 public DoubleMatrix2D getWeights() {
154 return this.weights;
155 }
156
157
158
159
160
161 private void mv() {
162 assert this.E != null;
163
164
165 DoubleMatrix1D Amean = new DenseDoubleMatrix1D(E.rows());
166 DoubleMatrix1D variance = Amean.like();
167 for (int i = 0; i < Amean.size(); i++) {
168 DoubleArrayList row = new DoubleArrayList(E.viewRow(i).toArray());
169 double rowMean = DescriptiveWithMissing.mean(row);
170 double rowVar = DescriptiveWithMissing.variance(row);
171 Amean.set(i, rowMean);
172 variance.set(i, rowVar);
173
174 }
175 this.meanVariance = new DenseDoubleMatrix2D(E.rows(), 2);
176 this.meanVariance.viewColumn(0).assign(Amean);
177 this.meanVariance.viewColumn(1).assign(variance);
178 }
179
180
181
182
183
184
185
186
187 private void voom(DoubleMatrix2D designMatrix) {
188 assert designMatrix != null;
189 assert this.E != null;
190 assert this.librarySize != null;
191
192 Algebra solver = new Algebra();
193
194 DoubleMatrix2D A = designMatrix;
195 weights = new DenseDoubleMatrix2D(E.rows(), E.columns());
196
197
198
199
200
201 LeastSquaresFitastSquaresFit.html#LeastSquaresFit">LeastSquaresFit lsf = new LeastSquaresFit(A, E);
202
203
204
205 DoubleMatrix1D Amean = MatrixRowStats.means(E);
206
207
208
209 DoubleMatrix1D sx = Amean.copy();
210 double meanLog2LibrarySize = librarySize.copy().assign(chain(log2, plus(1))).zSum() / librarySize.size();
211 sx.assign(plus(meanLog2LibrarySize));
212 sx.assign(minus(Math.log(Math.pow(10, 6)) / Math.log(2)));
213
214 DoubleMatrix1D sy = quarterRootVariance(lsf);
215
216
217
218 DoubleMatrix2D voomXY = new DenseDoubleMatrix2D(sx.size(), 2);
219 voomXY.viewColumn(0).assign(sx);
220 voomXY.viewColumn(1).assign(sy);
221 DoubleMatrix2D fit = Smooth.loessFit(voomXY);
222 this.meanVariance = voomXY;
223 this.loess = fit;
224
225
226 DoubleMatrix2D fittedValues = null;
227 QRDecompositionposition.html#QRDecomposition">QRDecomposition qr = new QRDecomposition(A);
228 DoubleMatrix2D coeff = lsf.getCoefficients();
229
230 if (qr.getRank() < A.columns()) {
231
232
233 IntArrayList pivot = qr.getPivotOrder();
234
235 IntArrayList subindices = (IntArrayList) pivot.partFromTo(0, qr.getRank() - 1);
236 int[] coeffAllCols = new int[coeff.columns()];
237 int[] desAllRows = new int[A.rows()];
238 for (int i = 0; i < coeffAllCols.length; i++) {
239 coeffAllCols[i] = i;
240 }
241 for (int i = 0; i < desAllRows.length; i++) {
242 desAllRows[i] = i;
243 }
244 DoubleMatrix2D coeffSlice = coeff.viewSelection(subindices.elements(), coeffAllCols);
245 DoubleMatrix2D ASlice = A.viewSelection(desAllRows, subindices.elements());
246 fittedValues = solver.mult(coeffSlice.viewDice(), ASlice.viewDice());
247 } else {
248
249 fittedValues = solver.mult(coeff.viewDice(), A.viewDice());
250 }
251
252
253
254
255
256 DoubleMatrix2D fittedCpm = fittedValues.copy().forEachNonZero(new IntIntDoubleFunction() {
257 @Override
258 public double apply(int row, int column, double third) {
259 return Math.pow(2, third);
260 }
261 });
262 DoubleMatrix2D fittedCount = fittedCpm.copy();
263 DoubleMatrix1D libSizePlusOne = librarySize.assign(plus(1));
264 for (int i = 0; i < fittedCount.rows(); i++) {
265 fittedCount.viewRow(i).assign(libSizePlusOne, mult);
266 fittedCount.viewRow(i).assign(mult(Math.pow(10, -6)));
267 }
268 DoubleMatrix2D fittedLogCount = fittedCount.copy().assign(log2);
269
270
271
272
273
274
275 double[] xInterpolate = new double[fittedLogCount.rows() * fittedLogCount.columns()];
276 int idx = 0;
277 for (int col = 0; col < fittedLogCount.columns(); col++) {
278 for (int row = 0; row < fittedLogCount.rows(); row++) {
279 xInterpolate[idx] = fittedLogCount.get(row, col);
280 idx++;
281 }
282 }
283
284
285 assert fit != null;
286 double[] yInterpolate = Smooth.interpolate(fit.viewColumn(0).toArray(),
287 fit.viewColumn(1).toArray(), xInterpolate);
288
289
290 idx = 0;
291 for (int col = 0; col < weights.columns(); col++) {
292 for (int row = 0; row < weights.rows(); row++) {
293 weights.set(row, col, (1.0 / Math.pow(yInterpolate[idx], 4)));
294 idx++;
295 }
296 }
297
298 }
299
300 protected DoubleMatrix1D quarterRootVariance(LeastSquaresFit lsf) {
301
302
303
304
305 DoubleMatrix2D residuals = lsf.getResiduals();
306 DoubleMatrix1D sy = new DenseDoubleMatrix1D(residuals.rows());
307
308 for (int row = 0; row < residuals.rows(); row++) {
309 double sum = 0;
310 for (int column = 0; column < residuals.columns(); column++) {
311 Double val = residuals.get(row, column);
312 if (!Double.isNaN(val)) {
313 sum += val * val;
314 }
315 }
316 sy.set(row, sum);
317 }
318
319
320
321 if (lsf.isHasMissing()) {
322
323 List<Integer> dofs = lsf.getResidualDofs();
324 assert dofs.size() == sy.size();
325 for (int i = 0; i < sy.size(); i++) {
326 sy.set(i, Math.sqrt(sy.get(i) / dofs.get(i)));
327 }
328 } else {
329 int dof = lsf.getResidualDof();
330 assert dof != 0;
331 sy.assign(chain(sqrt, div(dof)));
332 }
333
334 sy.assign(sqrt);
335 return sy;
336 }
337 }