1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  package ubic.basecode.math.linalg;
16  
17  import cern.colt.list.IntArrayList;
18  import cern.colt.matrix.DoubleMatrix1D;
19  import cern.colt.matrix.DoubleMatrix2D;
20  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
21  import cern.jet.math.Functions;
22  import no.uib.cipr.matrix.DenseMatrix;
23  import no.uib.cipr.matrix.Matrices;
24  import org.apache.commons.lang3.ArrayUtils;
25  import org.apache.commons.lang3.StringUtils;
26  import org.netlib.lapack.Dpotri;
27  import org.netlib.util.intW;
28  import org.slf4j.Logger;
29  import org.slf4j.LoggerFactory;
30  import ubic.basecode.dataStructure.matrix.DenseDoubleMatrix1D;
31  import ubic.basecode.dataStructure.matrix.MatrixUtil;
32  
33  
34  
35  
36  
37  
38  
39  public class QRDecomposition {
40  
41      private static Logger log = LoggerFactory.getLogger( QRDecomposition.class );
42  
43      private DenseDoubleMatrix2D chol2inv;
44  
45      private int[] jpvt;
46  
47      
48  
49  
50      private int n;
51  
52      
53  
54  
55      private int p;
56  
57      
58  
59  
60      private boolean pivoting = true;
61  
62      
63  
64  
65  
66      private DoubleMatrix2D QR;
67  
68      
69  
70  
71      private DoubleMatrix1D qraux;
72  
73      
74  
75  
76      private int rank = 0;
77  
78      
79  
80  
81      private DoubleMatrix1D Rdiag;
82  
83      
84  
85  
86      private double tolerance = 1e-7;
87  
88      
89  
90  
91      private DoubleMatrix2D Qcached = null;
92  
93      private DoubleMatrix2D effects;
94  
95      
96  
97  
98      public QRDecomposition( final DoubleMatrix2D A ) {
99          this( A, true );
100     }
101 
102     
103 
104 
105 
106 
107 
108     public QRDecomposition( final DoubleMatrix2D A, boolean pivoting ) {
109         
110         this.QR = A.copy();
111         this.n = A.rows();
112         this.p = A.columns();
113         this.Rdiag = A.like1D( p );
114         this.pivoting = pivoting;
115 
116         
117         jpvt = new int[p];
118         for ( int i = 0; i < p; i++ ) {
119             jpvt[i] = i;
120         }
121 
122         
123         DoubleMatrix1D originalNorms;
124         originalNorms = new DenseDoubleMatrix1D( p ); 
125         qraux = new DenseDoubleMatrix1D( p );
126         for ( int i = 0; i < p; i++ ) {
127             DoubleMatrix1D col = QR.viewColumn( i );
128             double norm2 = 0;
129             for ( int r = 0; r < n; r++ ) {
130                 norm2 = hypot( norm2, col.getQuick( r ) );
131             }
132             qraux.set( i, norm2 );
133             originalNorms.set( i, norm2 == 0.0 ? 1.0 : norm2 );
134         }
135 
136         
137         DoubleMatrix1D[] QRcolumns = new DoubleMatrix1D[p];
138         DoubleMatrix1D[] QRcolumnsPart = new DoubleMatrix1D[p];
139         for ( int v = 0; v < p; v++ ) {
140             QRcolumns[v] = QR.viewColumn( v );
141             QRcolumnsPart[v] = QR.viewColumn( v ).viewPart( v, n - v ); 
142         }
143 
144         rank = p;
145 
146         
147         for ( int v = 0; v < p; v++ ) {
148 
149             
150 
151 
152 
153 
154             while ( pivoting && v < rank && qraux.get( v ) < originalNorms.get( v ) * tolerance ) {
155                 log.debug( "Rotating " + v );
156                 rotate( QR, originalNorms, v );
157             }
158 
159             DoubleMatrix1D colv = QRcolumns[v];
160             double nrm = 0;
161             for ( int i = v; i < n; i++ ) {
162                 nrm = hypot( nrm, colv.getQuick( i ) );
163             }
164 
165             
166 
167 
168 
169 
170             if ( nrm != 0.0 ) {
171                 
172                 if ( QR.getQuick( v, v ) < 0 ) nrm = -nrm; 
173                 QRcolumnsPart[v].assign( Functions.div( nrm ) ); 
174 
175                 QR.setQuick( v, v, QR.getQuick( v, v ) + 1.0 ); 
176 
177                 
178                 for ( int j = v + 1; j < p; j++ ) {
179                     DoubleMatrix1D QRcolj = QR.viewColumn( j ).viewPart( v, n - v );
180                     double s = QRcolumnsPart[v].zDotProduct( QRcolj );
181 
182                     s = -s / QR.getQuick( v, v );
183                     for ( int i = v; i < n; i++ ) {
184                         QR.setQuick( i, j, QR.getQuick( i, j ) + s * QR.getQuick( i, v ) );
185                     }
186 
187                     if ( qraux.getQuick( j ) == 0 ) {
188                         continue;
189                     }
190 
191                     
192 
193 
194                     double tt = QR.getQuick( v, j ) / qraux.getQuick( j );
195                     double t = Math.max( 1.0 - Math.pow( tt, 2 ), 0.0 );
196                     if ( t < 1e-6 ) {
197                         DoubleMatrix1D col = QR.viewColumn( j );
198                         double nrmv = 0.0;
199                         for ( int r = v + 1; r < n; r++ ) {
200                             nrmv = hypot( nrmv, col.getQuick( r ) );
201                         }
202                         qraux.set( j, nrmv );
203                     } else {
204                         qraux.set( j, qraux.getQuick( j ) * Math.sqrt( t ) );
205                     }
206                 }
207                 if ( log.isDebugEnabled() ) log.debug( qraux.toString() );
208 
209             }
210 
211             
212             qraux.set( v, QR.getQuick( v, v ) );
213             QR.setQuick( v, v, -nrm );
214             Rdiag.setQuick( v, -nrm );
215         }
216         rank = Math.min( rank, n );
217     }
218 
219     
220 
221 
222 
223 
224     public DoubleMatrix2D chol2inv() {
225         
226         return dpotri( this.getR().viewPart( 0, 0, this.rank, this.rank ) );
227     }
228 
229     
230 
231 
232 
233 
234 
235 
236 
237 
238 
239 
240 
241     public DoubleMatrix1D effects( DoubleMatrix1D y ) {
242 
243         double[] qty = new double[y.size()];
244         double[] junk = new double[y.size()];
245         ubic.basecode.math.linalg.Dqrsl.dqrsl_j( QR.toArray(), QR.rows(), QR.columns(), qraux.toArray(), MatrixUtil.removeMissingOrInfinite( y ).toArray(),
246                 junk, qty,
247                 junk, junk, junk, 1000 );
248         return new DenseDoubleMatrix1D( qty );
249     }
250 
251     
252 
253 
254 
255 
256 
257 
258 
259     public DoubleMatrix2D effects( DoubleMatrix2D y ) {
260         double[][] efa = new double[y.columns()][y.rows()];
261         for ( int i = 0; i < y.columns(); i++ ) {
262             efa[i] = effects( y.viewColumn( i ) ).toArray();
263         }
264         return new DenseDoubleMatrix2D( efa ).viewDice();
265     }
266 
267     
268 
269 
270     public IntArrayList getPivotOrder() {
271         return new IntArrayList( jpvt );
272     }
273 
274     
275 
276 
277 
278 
279     public DoubleMatrix2D getQ() {
280 
281         
282         if ( this.Qcached != null ) return Qcached;
283 
284         DoubleMatrix2D Q = QR.like();
285 
286         for ( int i = 0; i < Q.columns(); i++ ) {
287             Q.set( i, i, 1 );
288         }
289 
290         for ( int jy = 0; jy < p; jy++ ) {
291 
292             DoubleMatrix1D y = Q.viewColumn( jy );
293 
294             for ( int jj = 1; jj <= p; jj++ ) {
295                 int j = p - jj;
296 
297                 if ( qraux.get( j ) == 0.0 ) {
298                     continue;
299                 }
300 
301                 double temp = QR.get( j, j );
302                 QR.set( j, j, qraux.get( j ) );
303                 DoubleMatrix1D QRcolv = QR.viewColumn( j ).viewPart( j, n - j );
304                 DoubleMatrix1D Qcolj = y.viewPart( j, n - j );
305                 double s = QRcolv.zDotProduct( Qcolj );
306                 s = -s / QR.getQuick( j, j );
307 
308                 Qcolj.assign( QRcolv, Functions.plusMult( s ) );
309 
310                 QR.set( j, j, temp );
311             }
312         }
313         this.Qcached = Q;
314         return Q;
315 
316     }
317 
318     
319 
320 
321     public DoubleMatrix1D getQraux() {
322         return qraux;
323     }
324 
325     
326 
327 
328 
329 
330     public DoubleMatrix2D getR() {
331         DoubleMatrix2D R = QR.like( p, p );
332         for ( int i = 0; i < p; i++ ) {
333             for ( int j = 0; j < p; j++ ) {
334                 if ( i < j )
335                     R.setQuick( i, j, QR.getQuick( i, j ) );
336                 else if ( i == j )
337                     R.setQuick( i, j, Rdiag.getQuick( i ) );
338                 else
339                     R.setQuick( i, j, 0 );
340             }
341         }
342         return R;
343     }
344 
345     
346 
347 
348     public int getRank() {
349         return rank;
350     }
351 
352     public double getTolerance() {
353         return tolerance;
354     }
355 
356     
357 
358 
359 
360 
361     public boolean hasFullRank() {
362         return rank == p;
363     }
364 
365     
366 
367 
368     public boolean isPivoting() {
369         return pivoting;
370     }
371 
372     
373 
374 
375 
376 
377 
378 
379 
380 
381 
382 
383     public DoubleMatrix2D solve( DoubleMatrix2D y ) {
384         if ( y.rows() != n ) {
385             throw new IllegalArgumentException( "Matrix row dimensions must agree." );
386         }
387 
388         if ( !pivoting && !this.hasFullRank() ) {
389             throw new IllegalArgumentException( "Matrix is rank deficient; try using pivoting" );
390         }
391 
392         DoubleMatrix2D qTy = effects( y ); 
393 
394         
395         for ( int k1 = rank - 1; k1 >= 0; k1-- ) {
396 
397             for ( int j = 0; j < y.columns(); j++ ) {
398                 qTy.setQuick( k1, j, qTy.getQuick( k1, j ) / Rdiag.getQuick( k1 ) );
399             }
400             for ( int i = 0; i < k1; i++ ) {
401                 
402                 for ( int j = 0; j < y.columns(); j++ ) {
403                     qTy.setQuick( i, j, qTy.getQuick( i, j ) - qTy.getQuick( k1, j ) * QR.getQuick( i, k1 ) );
404                 }
405             }
406         }
407 
408         
409 
410 
411         if ( this.rank < this.p ) {
412             for ( int i = rank; i < this.p; i++ ) {
413                 qTy.viewRow( i ).assign( Double.NaN );
414             }
415         }
416 
417         
418 
419 
420         DoubleMatrix2D coeff = qTy.like( p, y.columns() );
421         coeff.assign( Double.NaN );
422         for ( int i = 0; i < this.rank; i++ ) {
423             int piv = jpvt[i]; 
424             for ( int j = 0; j < qTy.columns(); j++ ) {
425                 coeff.setQuick( piv, j, qTy.getQuick( i, j ) );
426             }
427         }
428 
429         return coeff;
430 
431     }
432 
433     
434 
435 
436 
437     @Override
438     public String toString() {
439         StringBuilder buf = new StringBuilder();
440         String unknown = "Illegal operation or error: ";
441 
442         buf.append( "-----------------------------------------------------------------\n" );
443         buf.append( "QRDecomposition(A) \n" );
444         buf.append( "-----------------------------------------------------------------\n" );
445 
446         buf.append( "rank = " + rank );
447 
448         buf.append( "\n\nQ = " );
449         try {
450             buf.append( String.valueOf( this.getQ() ) );
451         } catch ( IllegalArgumentException exc ) {
452             buf.append( unknown + exc.getMessage() );
453         }
454 
455         buf.append( "\n\nR = " );
456         try {
457             buf.append( String.valueOf( this.getR() ) );
458         } catch ( IllegalArgumentException exc ) {
459             buf.append( unknown + exc.getMessage() );
460         }
461 
462         buf.append( "\n\nQRaux = " + this.qraux );
463 
464         return buf.toString();
465     }
466 
467     protected String diagnose() {
468         StringBuilder buf = new StringBuilder();
469         buf.append( "Rank = " + rank + "\n" );
470         
471         buf.append( "Qraux: " + qraux + "\n" );
472         buf.append( "Pivot indices: " + StringUtils.join( ArrayUtils.toObject( jpvt ), "  " ) + "\n" );
473         return buf.toString();
474     }
475 
476     
477 
478 
479 
480 
481     protected DoubleMatrix2D getQR() {
482         return QR;
483     }
484 
485     
486 
487 
488 
489 
490 
491     private DoubleMatrix2D dpotri( DoubleMatrix2D x ) {
492 
493         if ( this.chol2inv != null ) return this.chol2inv;
494 
495         DenseMatrix denseMatrix = new DenseMatrix( x.copy().toArray() );
496         intW status = new intW( 0 );
497         Dpotri.dpotri( "U", x.columns(), denseMatrix.getData(), 0, x.columns(), status );
498         if ( status.val != 0 ) {
499             throw new IllegalStateException( "Could not invert matrix" );
500         }
501 
502         this.chol2inv = new DenseDoubleMatrix2D( Matrices.getArray( denseMatrix ) );
503         return this.chol2inv;
504     }
505 
506     
507 
508 
509     private double hypot( double a, double b ) {
510         double r;
511         if ( Math.abs( a ) > Math.abs( b ) ) {
512             r = b / a;
513             r = Math.abs( a ) * Math.sqrt( 1 + r * r );
514         } else if ( b != 0 ) {
515             r = a / b;
516             r = Math.abs( b ) * Math.sqrt( 1 + r * r );
517         } else {
518             r = 0.0;
519         }
520         return r;
521     }
522 
523     
524 
525 
526 
527 
528     private void rotate( DoubleMatrix2D x, DoubleMatrix1D work, int v ) {
529         for ( int i = 0; i < n; i++ ) {
530             double t = x.get( i, v );
531             for ( int j = v; j < p - 1; j++ ) {
532                 x.set( i, j, x.get( i, j + 1 ) );
533             }
534             x.set( i, p - 1, t );
535         }
536 
537         
538         int i = jpvt[v];
539         double t = qraux.get( v );
540         double w0 = work.get( v );
541 
542         for ( int j = v; j < p - 1; j++ ) {
543             jpvt[j] = jpvt[j + 1];
544             qraux.set( j, qraux.get( j + 1 ) );
545             work.set( j, work.get( j + 1 ) );
546         }
547         jpvt[p - 1] = i;
548         qraux.set( p - 1, t );
549         work.set( p - 1, w0 );
550         rank = rank - 1;
551     }
552 
553 }