1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package ubic.basecode.math.linalg;
16
17 import cern.colt.list.IntArrayList;
18 import cern.colt.matrix.DoubleMatrix1D;
19 import cern.colt.matrix.DoubleMatrix2D;
20 import cern.colt.matrix.impl.DenseDoubleMatrix2D;
21 import cern.jet.math.Functions;
22 import no.uib.cipr.matrix.DenseMatrix;
23 import no.uib.cipr.matrix.Matrices;
24 import org.apache.commons.lang3.ArrayUtils;
25 import org.apache.commons.lang3.StringUtils;
26 import org.netlib.lapack.Dpotri;
27 import org.netlib.util.intW;
28 import org.slf4j.Logger;
29 import org.slf4j.LoggerFactory;
30 import ubic.basecode.dataStructure.matrix.DenseDoubleMatrix1D;
31 import ubic.basecode.dataStructure.matrix.MatrixUtil;
32
33
34
35
36
37
38
39 public class QRDecomposition {
40
41 private static Logger log = LoggerFactory.getLogger( QRDecomposition.class );
42
43 private DenseDoubleMatrix2D chol2inv;
44
45 private int[] jpvt;
46
47
48
49
50 private int n;
51
52
53
54
55 private int p;
56
57
58
59
60 private boolean pivoting = true;
61
62
63
64
65
66 private DoubleMatrix2D QR;
67
68
69
70
71 private DoubleMatrix1D qraux;
72
73
74
75
76 private int rank = 0;
77
78
79
80
81 private DoubleMatrix1D Rdiag;
82
83
84
85
86 private double tolerance = 1e-7;
87
88
89
90
91 private DoubleMatrix2D Qcached = null;
92
93 private DoubleMatrix2D effects;
94
95
96
97
98 public QRDecomposition( final DoubleMatrix2D A ) {
99 this( A, true );
100 }
101
102
103
104
105
106
107
108 public QRDecomposition( final DoubleMatrix2D A, boolean pivoting ) {
109
110 this.QR = A.copy();
111 this.n = A.rows();
112 this.p = A.columns();
113 this.Rdiag = A.like1D( p );
114 this.pivoting = pivoting;
115
116
117 jpvt = new int[p];
118 for ( int i = 0; i < p; i++ ) {
119 jpvt[i] = i;
120 }
121
122
123 DoubleMatrix1D originalNorms;
124 originalNorms = new DenseDoubleMatrix1D( p );
125 qraux = new DenseDoubleMatrix1D( p );
126 for ( int i = 0; i < p; i++ ) {
127 DoubleMatrix1D col = QR.viewColumn( i );
128 double norm2 = 0;
129 for ( int r = 0; r < n; r++ ) {
130 norm2 = hypot( norm2, col.getQuick( r ) );
131 }
132 qraux.set( i, norm2 );
133 originalNorms.set( i, norm2 == 0.0 ? 1.0 : norm2 );
134 }
135
136
137 DoubleMatrix1D[] QRcolumns = new DoubleMatrix1D[p];
138 DoubleMatrix1D[] QRcolumnsPart = new DoubleMatrix1D[p];
139 for ( int v = 0; v < p; v++ ) {
140 QRcolumns[v] = QR.viewColumn( v );
141 QRcolumnsPart[v] = QR.viewColumn( v ).viewPart( v, n - v );
142 }
143
144 rank = p;
145
146
147 for ( int v = 0; v < p; v++ ) {
148
149
150
151
152
153
154 while ( pivoting && v < rank && qraux.get( v ) < originalNorms.get( v ) * tolerance ) {
155 log.debug( "Rotating " + v );
156 rotate( QR, originalNorms, v );
157 }
158
159 DoubleMatrix1D colv = QRcolumns[v];
160 double nrm = 0;
161 for ( int i = v; i < n; i++ ) {
162 nrm = hypot( nrm, colv.getQuick( i ) );
163 }
164
165
166
167
168
169
170 if ( nrm != 0.0 ) {
171
172 if ( QR.getQuick( v, v ) < 0 ) nrm = -nrm;
173 QRcolumnsPart[v].assign( Functions.div( nrm ) );
174
175 QR.setQuick( v, v, QR.getQuick( v, v ) + 1.0 );
176
177
178 for ( int j = v + 1; j < p; j++ ) {
179 DoubleMatrix1D QRcolj = QR.viewColumn( j ).viewPart( v, n - v );
180 double s = QRcolumnsPart[v].zDotProduct( QRcolj );
181
182 s = -s / QR.getQuick( v, v );
183 for ( int i = v; i < n; i++ ) {
184 QR.setQuick( i, j, QR.getQuick( i, j ) + s * QR.getQuick( i, v ) );
185 }
186
187 if ( qraux.getQuick( j ) == 0 ) {
188 continue;
189 }
190
191
192
193
194 double tt = QR.getQuick( v, j ) / qraux.getQuick( j );
195 double t = Math.max( 1.0 - Math.pow( tt, 2 ), 0.0 );
196 if ( t < 1e-6 ) {
197 DoubleMatrix1D col = QR.viewColumn( j );
198 double nrmv = 0.0;
199 for ( int r = v + 1; r < n; r++ ) {
200 nrmv = hypot( nrmv, col.getQuick( r ) );
201 }
202 qraux.set( j, nrmv );
203 } else {
204 qraux.set( j, qraux.getQuick( j ) * Math.sqrt( t ) );
205 }
206 }
207 if ( log.isDebugEnabled() ) log.debug( qraux.toString() );
208
209 }
210
211
212 qraux.set( v, QR.getQuick( v, v ) );
213 QR.setQuick( v, v, -nrm );
214 Rdiag.setQuick( v, -nrm );
215 }
216 rank = Math.min( rank, n );
217 }
218
219
220
221
222
223
224 public DoubleMatrix2D chol2inv() {
225
226 return dpotri( this.getR().viewPart( 0, 0, this.rank, this.rank ) );
227 }
228
229
230
231
232
233
234
235
236
237
238
239
240
241 public DoubleMatrix1D effects( DoubleMatrix1D y ) {
242
243 double[] qty = new double[y.size()];
244 double[] junk = new double[y.size()];
245 ubic.basecode.math.linalg.Dqrsl.dqrsl_j( QR.toArray(), QR.rows(), QR.columns(), qraux.toArray(), MatrixUtil.removeMissingOrInfinite( y ).toArray(),
246 junk, qty,
247 junk, junk, junk, 1000 );
248 return new DenseDoubleMatrix1D( qty );
249 }
250
251
252
253
254
255
256
257
258
259 public DoubleMatrix2D effects( DoubleMatrix2D y ) {
260 double[][] efa = new double[y.columns()][y.rows()];
261 for ( int i = 0; i < y.columns(); i++ ) {
262 efa[i] = effects( y.viewColumn( i ) ).toArray();
263 }
264 return new DenseDoubleMatrix2D( efa ).viewDice();
265 }
266
267
268
269
270 public IntArrayList getPivotOrder() {
271 return new IntArrayList( jpvt );
272 }
273
274
275
276
277
278
279 public DoubleMatrix2D getQ() {
280
281
282 if ( this.Qcached != null ) return Qcached;
283
284 DoubleMatrix2D Q = QR.like();
285
286 for ( int i = 0; i < Q.columns(); i++ ) {
287 Q.set( i, i, 1 );
288 }
289
290 for ( int jy = 0; jy < p; jy++ ) {
291
292 DoubleMatrix1D y = Q.viewColumn( jy );
293
294 for ( int jj = 1; jj <= p; jj++ ) {
295 int j = p - jj;
296
297 if ( qraux.get( j ) == 0.0 ) {
298 continue;
299 }
300
301 double temp = QR.get( j, j );
302 QR.set( j, j, qraux.get( j ) );
303 DoubleMatrix1D QRcolv = QR.viewColumn( j ).viewPart( j, n - j );
304 DoubleMatrix1D Qcolj = y.viewPart( j, n - j );
305 double s = QRcolv.zDotProduct( Qcolj );
306 s = -s / QR.getQuick( j, j );
307
308 Qcolj.assign( QRcolv, Functions.plusMult( s ) );
309
310 QR.set( j, j, temp );
311 }
312 }
313 this.Qcached = Q;
314 return Q;
315
316 }
317
318
319
320
321 public DoubleMatrix1D getQraux() {
322 return qraux;
323 }
324
325
326
327
328
329
330 public DoubleMatrix2D getR() {
331 DoubleMatrix2D R = QR.like( p, p );
332 for ( int i = 0; i < p; i++ ) {
333 for ( int j = 0; j < p; j++ ) {
334 if ( i < j )
335 R.setQuick( i, j, QR.getQuick( i, j ) );
336 else if ( i == j )
337 R.setQuick( i, j, Rdiag.getQuick( i ) );
338 else
339 R.setQuick( i, j, 0 );
340 }
341 }
342 return R;
343 }
344
345
346
347
348 public int getRank() {
349 return rank;
350 }
351
352 public double getTolerance() {
353 return tolerance;
354 }
355
356
357
358
359
360
361 public boolean hasFullRank() {
362 return rank == p;
363 }
364
365
366
367
368 public boolean isPivoting() {
369 return pivoting;
370 }
371
372
373
374
375
376
377
378
379
380
381
382
383 public DoubleMatrix2D solve( DoubleMatrix2D y ) {
384 if ( y.rows() != n ) {
385 throw new IllegalArgumentException( "Matrix row dimensions must agree." );
386 }
387
388 if ( !pivoting && !this.hasFullRank() ) {
389 throw new IllegalArgumentException( "Matrix is rank deficient; try using pivoting" );
390 }
391
392 DoubleMatrix2D qTy = effects( y );
393
394
395 for ( int k1 = rank - 1; k1 >= 0; k1-- ) {
396
397 for ( int j = 0; j < y.columns(); j++ ) {
398 qTy.setQuick( k1, j, qTy.getQuick( k1, j ) / Rdiag.getQuick( k1 ) );
399 }
400 for ( int i = 0; i < k1; i++ ) {
401
402 for ( int j = 0; j < y.columns(); j++ ) {
403 qTy.setQuick( i, j, qTy.getQuick( i, j ) - qTy.getQuick( k1, j ) * QR.getQuick( i, k1 ) );
404 }
405 }
406 }
407
408
409
410
411 if ( this.rank < this.p ) {
412 for ( int i = rank; i < this.p; i++ ) {
413 qTy.viewRow( i ).assign( Double.NaN );
414 }
415 }
416
417
418
419
420 DoubleMatrix2D coeff = qTy.like( p, y.columns() );
421 coeff.assign( Double.NaN );
422 for ( int i = 0; i < this.rank; i++ ) {
423 int piv = jpvt[i];
424 for ( int j = 0; j < qTy.columns(); j++ ) {
425 coeff.setQuick( piv, j, qTy.getQuick( i, j ) );
426 }
427 }
428
429 return coeff;
430
431 }
432
433
434
435
436
437 @Override
438 public String toString() {
439 StringBuilder buf = new StringBuilder();
440 String unknown = "Illegal operation or error: ";
441
442 buf.append( "-----------------------------------------------------------------\n" );
443 buf.append( "QRDecomposition(A) \n" );
444 buf.append( "-----------------------------------------------------------------\n" );
445
446 buf.append( "rank = " + rank );
447
448 buf.append( "\n\nQ = " );
449 try {
450 buf.append( String.valueOf( this.getQ() ) );
451 } catch ( IllegalArgumentException exc ) {
452 buf.append( unknown + exc.getMessage() );
453 }
454
455 buf.append( "\n\nR = " );
456 try {
457 buf.append( String.valueOf( this.getR() ) );
458 } catch ( IllegalArgumentException exc ) {
459 buf.append( unknown + exc.getMessage() );
460 }
461
462 buf.append( "\n\nQRaux = " + this.qraux );
463
464 return buf.toString();
465 }
466
467 protected String diagnose() {
468 StringBuilder buf = new StringBuilder();
469 buf.append( "Rank = " + rank + "\n" );
470
471 buf.append( "Qraux: " + qraux + "\n" );
472 buf.append( "Pivot indices: " + StringUtils.join( ArrayUtils.toObject( jpvt ), " " ) + "\n" );
473 return buf.toString();
474 }
475
476
477
478
479
480
481 protected DoubleMatrix2D getQR() {
482 return QR;
483 }
484
485
486
487
488
489
490
491 private DoubleMatrix2D dpotri( DoubleMatrix2D x ) {
492
493 if ( this.chol2inv != null ) return this.chol2inv;
494
495 DenseMatrix denseMatrix = new DenseMatrix( x.copy().toArray() );
496 intW status = new intW( 0 );
497 Dpotri.dpotri( "U", x.columns(), denseMatrix.getData(), 0, x.columns(), status );
498 if ( status.val != 0 ) {
499 throw new IllegalStateException( "Could not invert matrix" );
500 }
501
502 this.chol2inv = new DenseDoubleMatrix2D( Matrices.getArray( denseMatrix ) );
503 return this.chol2inv;
504 }
505
506
507
508
509 private double hypot( double a, double b ) {
510 double r;
511 if ( Math.abs( a ) > Math.abs( b ) ) {
512 r = b / a;
513 r = Math.abs( a ) * Math.sqrt( 1 + r * r );
514 } else if ( b != 0 ) {
515 r = a / b;
516 r = Math.abs( b ) * Math.sqrt( 1 + r * r );
517 } else {
518 r = 0.0;
519 }
520 return r;
521 }
522
523
524
525
526
527
528 private void rotate( DoubleMatrix2D x, DoubleMatrix1D work, int v ) {
529 for ( int i = 0; i < n; i++ ) {
530 double t = x.get( i, v );
531 for ( int j = v; j < p - 1; j++ ) {
532 x.set( i, j, x.get( i, j + 1 ) );
533 }
534 x.set( i, p - 1, t );
535 }
536
537
538 int i = jpvt[v];
539 double t = qraux.get( v );
540 double w0 = work.get( v );
541
542 for ( int j = v; j < p - 1; j++ ) {
543 jpvt[j] = jpvt[j + 1];
544 qraux.set( j, qraux.get( j + 1 ) );
545 work.set( j, work.get( j + 1 ) );
546 }
547 jpvt[p - 1] = i;
548 qraux.set( p - 1, t );
549 work.set( p - 1, w0 );
550 rank = rank - 1;
551 }
552
553 }