View Javadoc
1   /*
2    * The baseCode project
3    *
4    * Copyright (c) 2011 University of British Columbia
5    *
6    * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
7    * the License. You may obtain a copy of the License at
8    *
9    * http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
12   * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
13   * specific language governing permissions and limitations under the License.
14   */
15  package ubic.basecode.math.linalg;
16  
17  import cern.colt.list.IntArrayList;
18  import cern.colt.matrix.DoubleMatrix1D;
19  import cern.colt.matrix.DoubleMatrix2D;
20  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
21  import cern.jet.math.Functions;
22  import no.uib.cipr.matrix.DenseMatrix;
23  import no.uib.cipr.matrix.Matrices;
24  import org.apache.commons.lang3.ArrayUtils;
25  import org.apache.commons.lang3.StringUtils;
26  import org.netlib.lapack.Dpotri;
27  import org.netlib.util.intW;
28  import org.slf4j.Logger;
29  import org.slf4j.LoggerFactory;
30  import ubic.basecode.dataStructure.matrix.DenseDoubleMatrix1D;
31  import ubic.basecode.dataStructure.matrix.MatrixUtil;
32  
33  /**
34   * QR with pivoting. See http://www.netlib.org/lapack/lug/node42.html and http://www.netlib.org/lapack/lug/node27.html,
35   * and Golub and VanLoan, section 5.5.6+. Designed to mimic the way R does this by default.
36   * 
37   * @author paul
38   */
39  public class QRDecomposition {
40  
41      private static Logger log = LoggerFactory.getLogger( QRDecomposition.class );
42  
43      private DenseDoubleMatrix2D chol2inv;
44  
45      private int[] jpvt;
46  
47      /**
48       * Rows in input
49       */
50      private int n;
51  
52      /**
53       * Columns in input
54       */
55      private int p;
56  
57      /**
58       * If pivoting was used.
59       */
60      private boolean pivoting = true;
61  
62      /**
63       * Contains the compact QR: R in the upper triangle, Q is recoverable from the lower part. FIXME if we have to
64       * access this as a raw array the DoubleMatrix2D API is inefficient.
65       */
66      private DoubleMatrix2D QR;
67  
68      /**
69       * Auxiliary information used to contsruct Q from the economy-sized QR.
70       */
71      private DoubleMatrix1D qraux;
72  
73      /**
74       * Rank
75       */
76      private int rank = 0;
77  
78      /**
79       * Diagonal of R
80       */
81      private DoubleMatrix1D Rdiag;
82  
83      /**
84       * Used to decide when to pivot
85       */
86      private double tolerance = 1e-7;
87  
88      /**
89       * 
90       */
91      private DoubleMatrix2D Qcached = null;
92  
93      private DoubleMatrix2D effects;
94  
95      /**
96       * @param A the matrix to decompose, pivoting will be used.
97       */
98      public QRDecomposition( final DoubleMatrix2D A ) {
99          this( A, true );
100     }
101 
102     /**
103      * Construct the QR decomposition of A. With pivoting = true, this reproduces quite closely the behaviour of R qr().
104      * 
105      * @param A the matrix to decompose
106      * @param pivoting set to false to obtain standard QR behaviour.
107      */
108     public QRDecomposition( final DoubleMatrix2D A, boolean pivoting ) {
109         // Initialize.
110         this.QR = A.copy();
111         this.n = A.rows();
112         this.p = A.columns();
113         this.Rdiag = A.like1D( p );
114         this.pivoting = pivoting;
115 
116         // initialization
117         jpvt = new int[p];
118         for ( int i = 0; i < p; i++ ) {
119             jpvt[i] = i;
120         }
121 
122         // initialization. We always compute qraux here.
123         DoubleMatrix1D originalNorms;
124         originalNorms = new DenseDoubleMatrix1D( p ); // "work" in linpack
125         qraux = new DenseDoubleMatrix1D( p );
126         for ( int i = 0; i < p; i++ ) {
127             DoubleMatrix1D col = QR.viewColumn( i );
128             double norm2 = 0;
129             for ( int r = 0; r < n; r++ ) {
130                 norm2 = hypot( norm2, col.getQuick( r ) );
131             }
132             qraux.set( i, norm2 );
133             originalNorms.set( i, norm2 == 0.0 ? 1.0 : norm2 );
134         }
135 
136         // precompute and cache some views to avoid regenerating them time and again
137         DoubleMatrix1D[] QRcolumns = new DoubleMatrix1D[p];
138         DoubleMatrix1D[] QRcolumnsPart = new DoubleMatrix1D[p];
139         for ( int v = 0; v < p; v++ ) {
140             QRcolumns[v] = QR.viewColumn( v );
141             QRcolumnsPart[v] = QR.viewColumn( v ).viewPart( v, n - v ); // upper triangle.
142         }
143 
144         rank = p;
145 
146         // Main loop.
147         for ( int v = 0; v < p; v++ ) {
148 
149             /*
150              * Rotate columns until we find one with a non-negligible norm. This is the pivoting strategy used in
151              * dqrdc2, which puts small columns to the right. See R documentation for qr and
152              * https://svn.r-project.org/R/trunk/src/appl/dqrdc2.f
153              */
154             while ( pivoting && v < rank && qraux.get( v ) < originalNorms.get( v ) * tolerance ) {
155                 log.debug( "Rotating " + v );
156                 rotate( QR, originalNorms, v );
157             }
158 
159             DoubleMatrix1D colv = QRcolumns[v];
160             double nrm = 0;
161             for ( int i = v; i < n; i++ ) {
162                 nrm = hypot( nrm, colv.getQuick( i ) );
163             }
164 
165             /*
166              * "Householder reflections can be used to calculate QR decompositions by reflecting first one column of a
167              * matrix onto a multiple of a standard basis vector, calculating the transformation matrix, multiplying it
168              * with the original matrix and then recursing down the (i, i) minors of that product."
169              */
170             if ( nrm != 0.0 ) {
171                 // Form k-th Householder vector: scale and flip
172                 if ( QR.getQuick( v, v ) < 0 ) nrm = -nrm; // dsign
173                 QRcolumnsPart[v].assign( Functions.div( nrm ) ); // dscal
174 
175                 QR.setQuick( v, v, QR.getQuick( v, v ) + 1.0 ); // update diagonal
176 
177                 // Apply transformation to remaining columns.
178                 for ( int j = v + 1; j < p; j++ ) {
179                     DoubleMatrix1D QRcolj = QR.viewColumn( j ).viewPart( v, n - v );
180                     double s = QRcolumnsPart[v].zDotProduct( QRcolj );
181 
182                     s = -s / QR.getQuick( v, v );
183                     for ( int i = v; i < n; i++ ) {
184                         QR.setQuick( i, j, QR.getQuick( i, j ) + s * QR.getQuick( i, v ) );
185                     }
186 
187                     if ( qraux.getQuick( j ) == 0 ) {
188                         continue;
189                     }
190 
191                     /*
192                      * Update the norm of this column. Used even if we are not pivoting.
193                      */
194                     double tt = QR.getQuick( v, j ) / qraux.getQuick( j );
195                     double t = Math.max( 1.0 - Math.pow( tt, 2 ), 0.0 );
196                     if ( t < 1e-6 ) {
197                         DoubleMatrix1D col = QR.viewColumn( j );
198                         double nrmv = 0.0;
199                         for ( int r = v + 1; r < n; r++ ) {
200                             nrmv = hypot( nrmv, col.getQuick( r ) );
201                         }
202                         qraux.set( j, nrmv );
203                     } else {
204                         qraux.set( j, qraux.getQuick( j ) * Math.sqrt( t ) );
205                     }
206                 }
207                 if ( log.isDebugEnabled() ) log.debug( qraux.toString() );
208 
209             }
210 
211             // save transformation parts we are done with.
212             qraux.set( v, QR.getQuick( v, v ) );
213             QR.setQuick( v, v, -nrm );
214             Rdiag.setQuick( v, -nrm );
215         }
216         rank = Math.min( rank, n );
217     }
218 
219     /**
220      * Used for computing standard errors of parameter estimates for least squares; copies functionality of R chol2inv.
221      * 
222      * @return
223      */
224     public DoubleMatrix2D chol2inv() {
225         // using "size" in dpotri doesn't work right.
226         return dpotri( this.getR().viewPart( 0, 0, this.rank, this.rank ) );
227     }
228 
229     /**
230      * Compute effects matrix Q'y (as in Rb = Q'y).
231      * 
232      * <p>
233      * "Tthe effects are the uncorrelated single-degree-of-freedom values obtained by projecting the data onto the
234      * successive orthogonal subspaces generated by the QR decomposition during the fitting process. The first r (the
235      * rank of the model) are associated with coefficients and the remainder span the space of residuals (but are not
236      * associated with particular residuals)."
237      * 
238      * @param y vector Missing values are ignored, otherwise assumed to be of the right size
239      * @return vector of effects - these are the projections of y into Q column space
240      */
241     public DoubleMatrix1D effects( DoubleMatrix1D y ) {
242 
243         double[] qty = new double[y.size()];
244         double[] junk = new double[y.size()];
245         ubic.basecode.math.linalg.Dqrsl.dqrsl_j( QR.toArray(), QR.rows(), QR.columns(), qraux.toArray(), MatrixUtil.removeMissingOrInfinite( y ).toArray(),
246                 junk, qty,
247                 junk, junk, junk, 1000 );
248         return new DenseDoubleMatrix1D( qty );
249     }
250 
251     /**
252      * Compute effects matrix Q'y (as in Rb = Q'y)
253      * 
254      * @param y matrix of data, assumed to be of right size, missing values are not supported
255      * @return matrix of effects - these are the projections of y's columns into Q column subspace associated with the
256      *         parameters,
257      *         so values are "effects" each basis vector on the data
258      */
259     public DoubleMatrix2D effects( DoubleMatrix2D y ) {
260         double[][] efa = new double[y.columns()][y.rows()];
261         for ( int i = 0; i < y.columns(); i++ ) {
262             efa[i] = effects( y.viewColumn( i ) ).toArray();
263         }
264         return new DenseDoubleMatrix2D( efa ).viewDice();
265     }
266 
267     /**
268      * @return
269      */
270     public IntArrayList getPivotOrder() {
271         return new IntArrayList( jpvt );
272     }
273 
274     /**
275      * Generates and returns the (economy-sized - first <tt>p</tt> columns only) orthogonal factor <tt>Q</tt>.
276      * 
277      * @return first <tt>p</tt> columns of <tt>Q</tt>
278      */
279     public DoubleMatrix2D getQ() {
280 
281         // For efficienty we do this... but really we should avoid directly getting Q.
282         if ( this.Qcached != null ) return Qcached;
283 
284         DoubleMatrix2D Q = QR.like();
285 
286         for ( int i = 0; i < Q.columns(); i++ ) {
287             Q.set( i, i, 1 );
288         }
289 
290         for ( int jy = 0; jy < p; jy++ ) {
291 
292             DoubleMatrix1D y = Q.viewColumn( jy );
293 
294             for ( int jj = 1; jj <= p; jj++ ) {
295                 int j = p - jj;
296 
297                 if ( qraux.get( j ) == 0.0 ) {
298                     continue;
299                 }
300 
301                 double temp = QR.get( j, j );
302                 QR.set( j, j, qraux.get( j ) );
303                 DoubleMatrix1D QRcolv = QR.viewColumn( j ).viewPart( j, n - j );
304                 DoubleMatrix1D Qcolj = y.viewPart( j, n - j );
305                 double s = QRcolv.zDotProduct( Qcolj );
306                 s = -s / QR.getQuick( j, j );
307 
308                 Qcolj.assign( QRcolv, Functions.plusMult( s ) );
309 
310                 QR.set( j, j, temp );
311             }
312         }
313         this.Qcached = Q;
314         return Q;
315 
316     }
317 
318     /**
319      * @return
320      */
321     public DoubleMatrix1D getQraux() {
322         return qraux;
323     }
324 
325     /**
326      * Returns the upper triangular factor, <tt>R</tt>.
327      * 
328      * @return <tt>R</tt>
329      */
330     public DoubleMatrix2D getR() {
331         DoubleMatrix2D R = QR.like( p, p );
332         for ( int i = 0; i < p; i++ ) {
333             for ( int j = 0; j < p; j++ ) {
334                 if ( i < j )
335                     R.setQuick( i, j, QR.getQuick( i, j ) );
336                 else if ( i == j )
337                     R.setQuick( i, j, Rdiag.getQuick( i ) );
338                 else
339                     R.setQuick( i, j, 0 );
340             }
341         }
342         return R;
343     }
344 
345     /**
346      * @return rank
347      */
348     public int getRank() {
349         return rank;
350     }
351 
352     public double getTolerance() {
353         return tolerance;
354     }
355 
356     /**
357      * Returns whether the matrix <tt>A</tt> has full rank.
358      * 
359      * @return true if <tt>R</tt>, and hence <tt>A</tt>, has full rank.
360      */
361     public boolean hasFullRank() {
362         return rank == p;
363     }
364 
365     /**
366      * @return true if pivoting was used (just whether it was set; not whether any actual pivoting happened)
367      */
368     public boolean isPivoting() {
369         return pivoting;
370     }
371 
372     /**
373      * Least squares solution of <tt>A*X = y</tt>; <tt>returns X</tt> using the stored QR decomposition of A.
374      * 
375      * @param y A matrix with as many rows as <tt>A</tt> and any number of columns. Least squares is fit for each column
376      *        of y.
377      * @return <tt>X</tt> that minimizes the two norm of <tt>Q*R*X - B</tt>.
378      * @exception IllegalArgumentException if <tt>y.rows() != A.rows()</tt>.
379      * @exception IllegalArgumentException if <tt>!this.hasFullRank()</tt> (<tt>A</tt> is rank deficient). However,
380      *            rank-deficient cases are handled by pivoting, so if you are using pivoting you should not see this
381      *            happening.
382      */
383     public DoubleMatrix2D solve( DoubleMatrix2D y ) {
384         if ( y.rows() != n ) {
385             throw new IllegalArgumentException( "Matrix row dimensions must agree." );
386         }
387 
388         if ( !pivoting && !this.hasFullRank() ) {
389             throw new IllegalArgumentException( "Matrix is rank deficient; try using pivoting" );
390         }
391 
392         DoubleMatrix2D qTy = effects( y ); // FIXME we use this again later, but we recompute it. Try to cache it.
393 
394         // Solve R*X = Y => X = RinvY; backsubstitution
395         for ( int k1 = rank - 1; k1 >= 0; k1-- ) {
396 
397             for ( int j = 0; j < y.columns(); j++ ) {
398                 qTy.setQuick( k1, j, qTy.getQuick( k1, j ) / Rdiag.getQuick( k1 ) );
399             }
400             for ( int i = 0; i < k1; i++ ) {
401                 // sum up to the parameter we've done.
402                 for ( int j = 0; j < y.columns(); j++ ) {
403                     qTy.setQuick( i, j, qTy.getQuick( i, j ) - qTy.getQuick( k1, j ) * QR.getQuick( i, k1 ) );
404                 }
405             }
406         }
407 
408         /*
409          * Drop coefficients we couldn't estimate. These will always be at the end, even if we pivoted ??????
410          */
411         if ( this.rank < this.p ) {
412             for ( int i = rank; i < this.p; i++ ) {
413                 qTy.viewRow( i ).assign( Double.NaN );
414             }
415         }
416 
417         /*
418          * Pad r1 back out to the full length p, and (if pivoted) in the right original order using jpvt
419          */
420         DoubleMatrix2D coeff = qTy.like( p, y.columns() );
421         coeff.assign( Double.NaN );
422         for ( int i = 0; i < this.rank; i++ ) {
423             int piv = jpvt[i]; // where the value should go.
424             for ( int j = 0; j < qTy.columns(); j++ ) {
425                 coeff.setQuick( piv, j, qTy.getQuick( i, j ) );
426             }
427         }
428 
429         return coeff;
430 
431     }
432 
433     /**
434      * Returns a String with (propertyName, propertyValue) pairs. Useful for debugging or to quickly get the rough
435      * picture.
436      */
437     @Override
438     public String toString() {
439         StringBuilder buf = new StringBuilder();
440         String unknown = "Illegal operation or error: ";
441 
442         buf.append( "-----------------------------------------------------------------\n" );
443         buf.append( "QRDecomposition(A) \n" );
444         buf.append( "-----------------------------------------------------------------\n" );
445 
446         buf.append( "rank = " + rank );
447 
448         buf.append( "\n\nQ = " );
449         try {
450             buf.append( String.valueOf( this.getQ() ) );
451         } catch ( IllegalArgumentException exc ) {
452             buf.append( unknown + exc.getMessage() );
453         }
454 
455         buf.append( "\n\nR = " );
456         try {
457             buf.append( String.valueOf( this.getR() ) );
458         } catch ( IllegalArgumentException exc ) {
459             buf.append( unknown + exc.getMessage() );
460         }
461 
462         buf.append( "\n\nQRaux = " + this.qraux );
463 
464         return buf.toString();
465     }
466 
467     protected String diagnose() {
468         StringBuilder buf = new StringBuilder();
469         buf.append( "Rank = " + rank + "\n" );
470         //  buf.append( "Work: " + originalNorms + "\n" );
471         buf.append( "Qraux: " + qraux + "\n" );
472         buf.append( "Pivot indices: " + StringUtils.join( ArrayUtils.toObject( jpvt ), "  " ) + "\n" );
473         return buf.toString();
474     }
475 
476     /**
477      * For testing.
478      * 
479      * @return
480      */
481     protected DoubleMatrix2D getQR() {
482         return QR;
483     }
484 
485     /**
486      * Mimics functionality of chol2inv from R (which just calls LAPACK::dpotri)
487      *
488      * @param x upper triangular matrix (from qr)
489      * @return symmetric matrix X'X^-1
490      */
491     private DoubleMatrix2D dpotri( DoubleMatrix2D x ) {
492 
493         if ( this.chol2inv != null ) return this.chol2inv;
494 
495         DenseMatrix denseMatrix = new DenseMatrix( x.copy().toArray() );
496         intW status = new intW( 0 );
497         Dpotri.dpotri( "U", x.columns(), denseMatrix.getData(), 0, x.columns(), status );
498         if ( status.val != 0 ) {
499             throw new IllegalStateException( "Could not invert matrix" );
500         }
501 
502         this.chol2inv = new DenseDoubleMatrix2D( Matrices.getArray( denseMatrix ) );
503         return this.chol2inv;
504     }
505 
506     /**
507      * Returns sqrt(a^2 + b^2) without under/overflow (from Colt)
508      */
509     private double hypot( double a, double b ) {
510         double r;
511         if ( Math.abs( a ) > Math.abs( b ) ) {
512             r = b / a;
513             r = Math.abs( a ) * Math.sqrt( 1 + r * r );
514         } else if ( b != 0 ) {
515             r = a / b;
516             r = Math.abs( b ) * Math.sqrt( 1 + r * r );
517         } else {
518             r = 0.0;
519         }
520         return r;
521     }
522 
523     /**
524      * @param x
525      * @param work
526      * @param v
527      */
528     private void rotate( DoubleMatrix2D x, DoubleMatrix1D work, int v ) {
529         for ( int i = 0; i < n; i++ ) {
530             double t = x.get( i, v );
531             for ( int j = v; j < p - 1; j++ ) {
532                 x.set( i, j, x.get( i, j + 1 ) );
533             }
534             x.set( i, p - 1, t );
535         }
536 
537         // do the same rotation to our helpers
538         int i = jpvt[v];
539         double t = qraux.get( v );
540         double w0 = work.get( v );
541 
542         for ( int j = v; j < p - 1; j++ ) {
543             jpvt[j] = jpvt[j + 1];
544             qraux.set( j, qraux.get( j + 1 ) );
545             work.set( j, work.get( j + 1 ) );
546         }
547         jpvt[p - 1] = i;
548         qraux.set( p - 1, t );
549         work.set( p - 1, w0 );
550         rank = rank - 1;
551     }
552 
553 }