View Javadoc
1   /*
2    * The baseCode project
3    * 
4    * Copyright (c) 2006 University of British Columbia
5    * 
6    * Licensed under the Apache License, Version 2.0 (the "License");
7    * you may not use this file except in compliance with the License.
8    * You may obtain a copy of the License at
9    *
10   *       http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   *
18   */
19  package ubic.basecode.math;
20  
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cern.colt.list.DoubleArrayList;
import cern.jet.math.Arithmetic;
import cern.jet.stat.Probability;
33  
34  /**
35   * Implements methods from supplementary file I of "Comparing functional annotation analyses with Catmap", Thomas
36   * Breslin, Patrik Edén and Morten Krogh, BMC Bioinformatics 2004, 5:193 doi:10.1186/1471-2105-5-193
37   * <p>
38   * Note that in the Catmap code, zero-based ranks are used, but these are converted to one-based before computation of
39   * pvalues. Therefore this code uses one-based ranks throughout.
40   * 
41   * @author pavlidis
42   * @version Id
43   * @see ROC
44   */
45  public class Wilcoxon {
46  
47      private static final Map<CacheKey, BigInteger> cache = new ConcurrentHashMap<CacheKey, BigInteger>();
48  
49      /**
50       * For smaller sample sizes, we compute exactly. Below 1e5 we start to notice some loss of precision (like one part
51       * in 1e5). Setting this too high really slows things down for high-throughput applications.
52       */
53      private static long LIMIT_FOR_APPROXIMATION = 100000L;
54  
55      private static Logger log = LoggerFactory.getLogger( Wilcoxon.class );
56  
57      /**
58       * Convenience method that computes a p-value using input of two double arrays. They must not contain missing values
59       * or ties.
60       * 
61       * @param a
62       * @param b
63       * @return
64       */
65      public static double exactWilcoxonP( double[] a, double[] b ) {
66          int fullLength = a.length + b.length;
67          // need sum of A's ranks with respect to all
68          DoubleArrayList ad = new DoubleArrayList( a );
69          DoubleArrayList bb = new DoubleArrayList( b );
70          ad.addAllOf( bb );
71          DoubleArrayList abR = Rank.rankTransform( ad );
72          int aSum = 0;
73          for ( int i = 0; i < a.length; i++ ) {
74              aSum += abR.get( i );
75          }
76  
77          if ( aSum > LIMIT_FOR_APPROXIMATION ) {
78              throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of rank sum will fail." );
79          }
80          return pExact( fullLength, a.length, aSum );
81      }
82  
83      /**
84       * @param N
85       * @param n
86       * @param R
87       * @return
88       */
89      public static double exactWilcoxonP( int N, int n, int R ) {
90          if ( R > LIMIT_FOR_APPROXIMATION ) {
91              throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of R will fail." );
92          }
93          return pExact( N, n, R );
94      }
95  
96      /**
97       * Only use when you know there are no ties.
98       * 
99       * @param N
100      * @param n
101      * @param R
102      * @return
103      */
104     public static double wilcoxonP( int N, int n, long R ) {
105         return wilcoxonP( N, n, R, false );
106     }
107 
108     /**
109      * @param N number of all Items
110      * @param n number of class Items
111      * @param R rankSum for items in the class. (one-based)
112      * @param ties set to true if you know there are ties
113      * @return
114      */
115     public static double wilcoxonP( int N, int n, long R, boolean ties ) {
116 
117         if ( n > N ) throw new IllegalArgumentException( "n must be less than N (n=" + n + ", N=" + N + ")" );
118 
119         if ( n == 0 && N == 0 ) return 1.0;
120 
121         if ( ( !ties )
122                 && ( ( ( long ) N * ( long ) n <= LIMIT_FOR_APPROXIMATION && n * R <= LIMIT_FOR_APPROXIMATION && ( long ) N
123                         * ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
124             if ( log.isDebugEnabled() ) log.debug( "Using exact method (" + N * n * R + ")" );
125             return pExact( N, n, ( int ) R );
126         }
127 
128         double p = pGaussian( N, n, R );
129 
130         if ( p < 0.1 && Math.pow( n, 2 ) * R / N <= 1e5 ) {
131             if ( log.isDebugEnabled() ) log.debug( "Using volume method (" + N * n * R + ")" );
132             return pVolume( N, n, R );
133         }
134 
135         if ( log.isDebugEnabled() ) log.debug( "Using gaussian method (" + N * n * R + ")" );
136         return p;
137     }
138 
139     /**
140      * @param N total number of items (in and not in the class)
141      * @param ranks of items in the class (one-based)
142      * @return
143      */
144     public static double wilcoxonP( int N, List<Double> ranks ) {
145 
146         /*
147          * Check for ties; cannot compute exact when there are ties.
148          */
149         Collections.sort( ranks );
150         Double p = null;
151         boolean ties = false;
152         for ( Double r : ranks ) {
153             if ( p != null ) {
154                 if ( r.equals( p ) ) {
155                     ties = true;
156                     break;
157                 }
158             }
159             p = r;
160         }
161 
162         long rankSum = Rank.rankSum( ranks );
163 
164         return wilcoxonP( N, ranks.size(), rankSum, ties );
165     }
166 
167     private static void addToCache( long N, long n, long R, BigInteger value ) {
168         cache.put( new CacheKey( N, n, R ), value );
169     }
170 
171     /**
172      * @param n0
173      * @param n02
174      * @param r0
175      * @return
176      */
177     private static boolean cacheContains( long N, long n, long R ) {
178         return cache.containsKey( new CacheKey( N, n, R ) );
179     }
180 
181     /**
182      * Direct port from catmap code. Exact computation of the number of ways n items can be drawn from a total of N
183      * items with a rank sum of R or better (lower).
184      * 
185      * @param N0
186      * @param n0
187      * @param R0 rank sum, 1-based (best rank is 1)
188      * @return
189      */
190     private static BigInteger computeA__( int N0, int n0, int R0 ) {
191         cache.clear();
192         if ( R0 < N0 ) N0 = R0;
193 
194         // if ( cacheContains( N0, n0, R0 ) ) {
195         // return getFromCache( N0, n0, R0 );
196         // }
197 
198         if ( N0 == 0 && n0 == 0 ) return BigInteger.ONE;
199 
200         for ( int N = 1; N <= N0; N++ ) {
201             if ( N > 2 ) removeFromCache( N - 2 );
202 
203             /* n has to be less than N */
204             long min_n = Math.max( 0, n0 + N - N0 );
205             long max_n = Math.min( n0, N );
206 
207             assert min_n >= 0;
208             assert max_n >= min_n;
209 
210             for ( long n = min_n; n <= max_n; n++ ) {
211 
212                 /* The rank sum is in the interval n(n+1)/2 to n(2N-n+1)/2. Other values need not be looked at. */
213                 long bestPossibleRankSum = n * ( n + 1 ) / 2;
214                 long worstPossibleRankSum = n * ( 2 * N - n + 1 ) / 2;
215 
216                 /* Ensure value looked at is valid for the original set of parameters. */
217                 long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1 ) * ( N0 - N ) / 2 );
218                 long max_r = Math.min( worstPossibleRankSum, R0 );
219 
220                 assert min_r >= 0;
221 
222                 if ( min_r > max_r ) {
223                     throw new IllegalStateException( min_r + " > " + max_r );
224                 }
225 
226                 assert max_r >= min_r : String.format( "max_r %d < min_r %d for N=%d, n=%d, r=%d", max_r, min_r, N0,
227                         n0, R0 );
228 
229                 /* R greater than this, have already computed it in parts */
230                 long foo = n * ( 2 * N - n - 1 ) / 2;
231 
232                 /* R less than this, we have already computed it in parts */
233                 long bar = N + ( n - 1 ) * n / 2;
234 
235                 for ( long r = min_r; r <= max_r; r++ ) {
236 
237                     if ( n == 0 || n == N || r == bestPossibleRankSum ) {
238                         addToCache( N, n, r, BigInteger.ONE );
239 
240                     } else if ( r > foo ) {
241                         addToCache( N, n, r, getFromCache( N - 1, n, foo ).add( getFromCache( N - 1, n - 1, r - N ) ) );
242 
243                     } else if ( r < bar ) {
244                         addToCache( N, n, r, getFromCache( N - 1, n, r ) );
245 
246                     } else {
247                         addToCache( N, n, r, getFromCache( N - 1, n, r ).add( getFromCache( N - 1, n - 1, r - N ) ) );
248                     }
249                 }
250             }
251         }
252         return getFromCache( N0, n0, R0 );
253     }
254 
255     /**
256      * @param N
257      * @param n
258      * @param R
259      * @return
260      */
261     private static BigInteger getFromCache( long N, long n, long R ) {
262 
263         if ( !cacheContains( N, n, R ) ) {
264             throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
265         }
266         return cache.get( new CacheKey( N, n, R ) );
267     }
268 
269     /**
270      * @param N
271      * @param n
272      * @param r rank sum, 1-based (best rank is 1).
273      * @return
274      */
275     private static double pExact( int N, int n, int R ) {
276         return computeA__( N, n, R ).doubleValue() / Arithmetic.binomial( N, n );
277     }
278 
279     /**
280      * @param N
281      * @param n
282      * @param R
283      * @return Upper-tail probability for Wilcoxon rank-sum test.
284      */
285     private static double pGaussian( long N, long n, long R ) {
286         if ( n > N ) throw new IllegalArgumentException( "n must be smaller than N" );
287         double mean = n * ( N + 1 ) / 2.0;
288         double var = n * ( N - n ) * ( N + 1 ) / 12.0;
289         if ( log.isDebugEnabled() ) log.debug( "Mean=" + mean + " Var=" + var + " R=" + R );
290         return Probability.normal( 0.0, var, R - mean );
291     }
292 
293     /**
294      * Directly ported from catmap.
295      * 
296      * @param N
297      * @param n
298      * @param R
299      * @return
300      */
301     private static double pVolume( int N, int n, long R ) {
302 
303         double t = R / ( double ) N;
304 
305         if ( t < 0 ) return 0.0;
306         if ( t >= n ) return 1.0;
307         double[] logFactors = new double[n + 1];
308         logFactors[0] = 0.0;
309         logFactors[1] = 0.0;
310         for ( int i = 2; i <= n; i++ ) {
311             logFactors[i] = logFactors[i - 1] + Math.log( i );
312         }
313 
314         int kMax = ( int ) t;
315         double[][] C = new double[n][];
316         for ( int i = 0; i < n; i++ ) {
317             C[i] = new double[n + 1];
318             C[0][i] = 0.0;
319         }
320 
321         C[0][n] = Math.exp( -logFactors[n] );
322         for ( int k = 1; k <= kMax; k++ ) {
323             for ( int a = 0; a < n; a++ ) {
324                 for ( int j = a; j <= n; j++ ) {
325                     C[k][a] += C[k - 1][j] * Math.exp( logFactors[j] - logFactors[a] - logFactors[j - a] );
326                 }
327             }
328             double b = Math.exp( -logFactors[k] - logFactors[n - 1 - k] ) / n;
329             C[k][n] = k % 2 != 0 ? -b : b;
330         }
331 
332         double result = 0.0;
333         for ( int a = 0; a <= n; a++ ) {
334             result += C[kMax][a] * Math.pow( t - kMax, a );
335         }
336         return result;
337 
338     }
339 
340     /**
341      * @param i
342      */
343     private static void removeFromCache( int N ) {
344         cache.remove( N );
345     }
346 
347 }
348 
349 class CacheKey {
350     private long n;
351     private long N;
352     private long R;
353 
354     public CacheKey( long N, long n, long R ) {
355         super();
356         this.N = N;
357         this.n = n;
358         this.R = R;
359     }
360 
361     @Override
362     public boolean equals( Object obj ) {
363         CacheKey other = ( CacheKey ) obj;
364 
365         if ( N != other.N ) return false;
366         if ( n != other.n ) return false;
367         if ( R != other.R ) return false;
368 
369         return true;
370     }
371 
372     @Override
373     public int hashCode() {
374         final int prime = 31;
375         long result = 1;
376         result = prime * result + N;
377         result = prime * result + n;
378         result = prime * result + R;
379         return ( int ) result; // problem: overflows?
380     }
381 
382 }