View Javadoc
1   /*
2    * The baseCode project
3    *
4    * Copyright (c) 2006 University of British Columbia
5    *
6    * Licensed under the Apache License, Version 2.0 (the "License");
7    * you may not use this file except in compliance with the License.
8    * You may obtain a copy of the License at
9    *
10   *       http://www.apache.org/licenses/LICENSE-2.0
11   *
12   * Unless required by applicable law or agreed to in writing, software
13   * distributed under the License is distributed on an "AS IS" BASIS,
14   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15   * See the License for the specific language governing permissions and
16   * limitations under the License.
17   *
18   */
19  package ubic.basecode.math;
20  
21  import cern.colt.list.DoubleArrayList;
22  import cern.jet.math.Arithmetic;
23  import cern.jet.stat.Probability;
24  import org.slf4j.Logger;
25  import org.slf4j.LoggerFactory;
26  
27  import java.math.BigInteger;
28  import java.util.Collections;
29  import java.util.HashMap;
30  import java.util.List;
31  import java.util.Map;
32  
33  /**
34   * Implements methods from supplementary file I of "Comparing functional annotation analyses with Catmap", Thomas
35   * Breslin, Patrik Edén and Morten Krogh, BMC Bioinformatics 2004, 5:193 doi:10.1186/1471-2105-5-193
36   * <p>
37   * Note that in the Catmap code, zero-based ranks are used, but these are converted to one-based before computation of
38   * pvalues. Therefore this code uses one-based ranks throughout.
39   *
40   * @author pavlidis
41   * @version Id
42   * @see ROC
43   */
44  public class Wilcoxon {
45  
46  
47      /**
48       * For smaller sample sizes, we compute exactly. Below 1e5 we start to notice some loss of precision (like one part
49       * in 1e5). Setting this too high really slows things down for high-throughput applications.
50       */
51      private static final long LIMIT_FOR_APPROXIMATION = 100000L;
52  
53      private static final Logger log = LoggerFactory.getLogger( Wilcoxon.class );
54  
55      /**
56       * Convenience method that computes a p-value using input of two double arrays. They must not contain missing values
57       * or ties.
58       */
59      public static double exactWilcoxonP( double[] a, double[] b ) {
60          int fullLength = a.length + b.length;
61          // need sum of A's ranks with respect to all
62          DoubleArrayList ad = new DoubleArrayList( a );
63          DoubleArrayList bb = new DoubleArrayList( b );
64          ad.addAllOf( bb );
65          DoubleArrayList abR = Rank.rankTransform( ad );
66          int aSum = 0;
67          for ( int i = 0; i < a.length; i++ ) {
68              aSum += abR.get( i );
69          }
70  
71          if ( aSum > LIMIT_FOR_APPROXIMATION ) {
72              throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of rank sum will fail." );
73          }
74          return pExact( fullLength, a.length, aSum );
75      }
76  
77      public static double exactWilcoxonP( int N, int n, int R ) {
78          if ( R > LIMIT_FOR_APPROXIMATION ) {
79              throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of R will fail." );
80          }
81          return pExact( N, n, R );
82      }
83  
84      /**
85       * Only use when you know there are no ties.
86       */
87      public static double wilcoxonP( int N, int n, long R ) {
88          return wilcoxonP( N, n, R, false );
89      }
90  
91      /**
92       * @param N    number of all Items
93       * @param n    number of class Items
94       * @param R    rankSum for items in the class. (one-based)
95       * @param ties set to true if you know there are ties
96       */
97      public static double wilcoxonP( int N, int n, long R, boolean ties ) {
98  
99          if ( n > N ) throw new IllegalArgumentException( "n must be less than N (n=" + n + ", N=" + N + ")" );
100 
101         if ( n == 0 && N == 0 ) return 1.0;
102 
103         if ( ( !ties )
104                 && ( ( ( long ) N * ( long ) n <= LIMIT_FOR_APPROXIMATION && n * R <= LIMIT_FOR_APPROXIMATION && ( long ) N
105                 * ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
106             if ( log.isDebugEnabled() ) log.debug( "Using exact method (" + N * n * R + ")" );
107             return pExact( N, n, ( int ) R );
108         }
109 
110         double p = pGaussian( N, n, R );
111 
112         if ( p < 0.1 && Math.pow( n, 2 ) * R / N <= 1e5 ) {
113             if ( log.isDebugEnabled() ) log.debug( "Using volume method (" + N * n * R + ")" );
114             return pVolume( N, n, R );
115         }
116 
117         if ( log.isDebugEnabled() ) log.debug( "Using gaussian method (" + N * n * R + ")" );
118         return p;
119     }
120 
121     /**
122      * @param N     total number of items (in and not in the class)
123      * @param ranks of items in the class (one-based)
124      */
125     public static double wilcoxonP( int N, List<Double> ranks ) {
126 
127         /*
128          * Check for ties; cannot compute exact when there are ties.
129          */
130         Collections.sort( ranks );
131         Double p = null;
132         boolean ties = false;
133         for ( Double r : ranks ) {
134             if ( p != null ) {
135                 if ( r.equals( p ) ) {
136                     ties = true;
137                     break;
138                 }
139             }
140             p = r;
141         }
142 
143         long rankSum = Rank.rankSum( ranks );
144 
145         return wilcoxonP( N, ranks.size(), rankSum, ties );
146     }
147 
148     /**
149      * Direct port from catmap code. Exact computation of the number of ways n items can be drawn from a total of N
150      * items with a rank sum of R or better (lower).
151      *
152      * @param R0 rank sum, 1-based (best rank is 1)
153      */
154     private static BigInteger computeA__( int N0, int n0, int R0 ) {
155         Map<CacheKey, BigInteger> cache = new HashMap<>();
156 
157         if ( R0 < N0 ) N0 = R0;
158 
159         if ( N0 == 0 && n0 == 0 ) return BigInteger.ONE;
160 
161         for ( int N = 1; N <= N0; N++ ) {
162             /* n has to be less than N */
163             long min_n = Math.max( 0, n0 + N - N0 );
164             long max_n = Math.min( n0, N );
165 
166             assert max_n >= min_n;
167 
168             for ( long n = min_n; n <= max_n; n++ ) {
169 
170                 /* The rank sum is in the interval n(n+1)/2 to n(2N-n+1)/2. Other values need not be looked at. */
171                 long bestPossibleRankSum = n * ( n + 1 ) / 2;
172                 long worstPossibleRankSum = n * ( 2L * N - n + 1 ) / 2;
173 
174                 /* Ensure value looked at is valid for the original set of parameters. */
175                 long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1L ) * ( N0 - N ) / 2 );
176                 long max_r = Math.min( worstPossibleRankSum, R0 );
177 
178                 assert min_r >= 0;
179 
180                 if ( min_r > max_r ) {
181                     throw new IllegalStateException( min_r + " > " + max_r );
182                 }
183 
184                 assert max_r >= min_r : String.format( "max_r %d < min_r %d for N=%d, n=%d, r=%d", max_r, min_r, N0,
185                         n0, R0 );
186 
187                 /* R greater than this, have already computed it in parts */
188                 long foo = n * ( 2L * N - n - 1 ) / 2;
189 
190                 /* R less than this, we have already computed it in parts */
191                 long bar = N + ( n - 1 ) * n / 2;
192 
193                 for ( long r = min_r; r <= max_r; r++ ) {
194 
195                     if ( n == 0 || n == N || r == bestPossibleRankSum ) {
196                         addToCache( cache, N, n, r, BigInteger.ONE );
197 
198                     } else if ( r > foo ) {
199                         addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, foo ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
200 
201                     } else if ( r < bar ) {
202                         addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ) );
203 
204                     } else {
205                         addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
206                     }
207                 }
208             }
209         }
210 
211         return getFromCache( cache, N0, n0, R0 );
212     }
213 
214     /**
215      * @param R rank sum, 1-based (best rank is 1).
216      */
217     private static double pExact( int N, int n, int R ) {
218         return computeA__( N, n, R ).doubleValue() / Arithmetic.binomial( N, n );
219     }
220 
221     /**
222      * @return Upper-tail probability for Wilcoxon rank-sum test.
223      */
224     private static double pGaussian( long N, long n, long R ) {
225         if ( n > N ) throw new IllegalArgumentException( "n must be smaller than N" );
226         double mean = n * ( N + 1 ) / 2.0;
227         double var = n * ( N - n ) * ( N + 1 ) / 12.0;
228         if ( log.isDebugEnabled() ) log.debug( "Mean=" + mean + " Var=" + var + " R=" + R );
229         return Probability.normal( 0.0, var, R - mean );
230     }
231 
232     /**
233      * Directly ported from catmap.
234      */
235     private static double pVolume( int N, int n, long R ) {
236 
237         double t = R / ( double ) N;
238 
239         if ( t < 0 ) return 0.0;
240         if ( t >= n ) return 1.0;
241         double[] logFactors = new double[n + 1];
242         logFactors[0] = 0.0;
243         logFactors[1] = 0.0;
244         for ( int i = 2; i <= n; i++ ) {
245             logFactors[i] = logFactors[i - 1] + Math.log( i );
246         }
247 
248         int kMax = ( int ) t;
249         double[][] C = new double[n][];
250         for ( int i = 0; i < n; i++ ) {
251             C[i] = new double[n + 1];
252             C[0][i] = 0.0;
253         }
254 
255         C[0][n] = Math.exp( -logFactors[n] );
256         for ( int k = 1; k <= kMax; k++ ) {
257             for ( int a = 0; a < n; a++ ) {
258                 for ( int j = a; j <= n; j++ ) {
259                     C[k][a] += C[k - 1][j] * Math.exp( logFactors[j] - logFactors[a] - logFactors[j - a] );
260                 }
261             }
262             double b = Math.exp( -logFactors[k] - logFactors[n - 1 - k] ) / n;
263             C[k][n] = k % 2 != 0 ? -b : b;
264         }
265 
266         double result = 0.0;
267         for ( int a = 0; a <= n; a++ ) {
268             result += C[kMax][a] * Math.pow( t - kMax, a );
269         }
270         return result;
271 
272     }
273 
274     private static void addToCache( Map<CacheKey, BigInteger> cache, long N, long n, long R, BigInteger value ) {
275         cache.put( new CacheKey( N, n, R ), value );
276     }
277 
278     private static boolean cacheContains( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
279         return cache.containsKey( new CacheKey( N, n, R ) );
280     }
281 
282     private static BigInteger getFromCache( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
283 
284         if ( !cacheContains( cache, N, n, R ) ) {
285             throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
286         }
287         return cache.get( new CacheKey( N, n, R ) );
288     }
289 
290     private static class CacheKey {
291         private final long n;
292         private final long N;
293         private final long R;
294 
295         public CacheKey( long N, long n, long R ) {
296             this.N = N;
297             this.n = n;
298             this.R = R;
299         }
300 
301         @Override
302         public boolean equals( Object obj ) {
303             if ( obj instanceof CacheKey ) {
304                 CacheKey other = ( CacheKey ) obj;
305                 return N == other.N && n == other.n && R == other.R;
306             } else {
307                 return false;
308             }
309         }
310 
311         @Override
312         public int hashCode() {
313             final int prime = 31;
314             long result = 1;
315             result = prime * result + N;
316             result = prime * result + n;
317             result = prime * result + R;
318             return ( int ) result; // problem: overflows?
319         }
320     }
321 }