1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package ubic.basecode.math;
20
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cern.colt.list.DoubleArrayList;
import cern.jet.math.Arithmetic;
import cern.jet.stat.Probability;
33
34
35
36
37
38
39
40
41
42
43
44
45 public class Wilcoxon {
46
47 private static final Map<CacheKey, BigInteger> cache = new ConcurrentHashMap<CacheKey, BigInteger>();
48
49
50
51
52
53 private static long LIMIT_FOR_APPROXIMATION = 100000L;
54
55 private static Logger log = LoggerFactory.getLogger( Wilcoxon.class );
56
57
58
59
60
61
62
63
64
65 public static double exactWilcoxonP( double[] a, double[] b ) {
66 int fullLength = a.length + b.length;
67
68 DoubleArrayList ad = new DoubleArrayList( a );
69 DoubleArrayList bb = new DoubleArrayList( b );
70 ad.addAllOf( bb );
71 DoubleArrayList abR = Rank.rankTransform( ad );
72 int aSum = 0;
73 for ( int i = 0; i < a.length; i++ ) {
74 aSum += abR.get( i );
75 }
76
77 if ( aSum > LIMIT_FOR_APPROXIMATION ) {
78 throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of rank sum will fail." );
79 }
80 return pExact( fullLength, a.length, aSum );
81 }
82
83
84
85
86
87
88
89 public static double exactWilcoxonP( int N, int n, int R ) {
90 if ( R > LIMIT_FOR_APPROXIMATION ) {
91 throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of R will fail." );
92 }
93 return pExact( N, n, R );
94 }
95
96
97
98
99
100
101
102
103
104 public static double wilcoxonP( int N, int n, long R ) {
105 return wilcoxonP( N, n, R, false );
106 }
107
108
109
110
111
112
113
114
115 public static double wilcoxonP( int N, int n, long R, boolean ties ) {
116
117 if ( n > N ) throw new IllegalArgumentException( "n must be less than N (n=" + n + ", N=" + N + ")" );
118
119 if ( n == 0 && N == 0 ) return 1.0;
120
121 if ( ( !ties )
122 && ( ( ( long ) N * ( long ) n <= LIMIT_FOR_APPROXIMATION && n * R <= LIMIT_FOR_APPROXIMATION && ( long ) N
123 * ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
124 if ( log.isDebugEnabled() ) log.debug( "Using exact method (" + N * n * R + ")" );
125 return pExact( N, n, ( int ) R );
126 }
127
128 double p = pGaussian( N, n, R );
129
130 if ( p < 0.1 && Math.pow( n, 2 ) * R / N <= 1e5 ) {
131 if ( log.isDebugEnabled() ) log.debug( "Using volume method (" + N * n * R + ")" );
132 return pVolume( N, n, R );
133 }
134
135 if ( log.isDebugEnabled() ) log.debug( "Using gaussian method (" + N * n * R + ")" );
136 return p;
137 }
138
139
140
141
142
143
144 public static double wilcoxonP( int N, List<Double> ranks ) {
145
146
147
148
149 Collections.sort( ranks );
150 Double p = null;
151 boolean ties = false;
152 for ( Double r : ranks ) {
153 if ( p != null ) {
154 if ( r.equals( p ) ) {
155 ties = true;
156 break;
157 }
158 }
159 p = r;
160 }
161
162 long rankSum = Rank.rankSum( ranks );
163
164 return wilcoxonP( N, ranks.size(), rankSum, ties );
165 }
166
167 private static void addToCache( long N, long n, long R, BigInteger value ) {
168 cache.put( new CacheKey( N, n, R ), value );
169 }
170
171
172
173
174
175
176
177 private static boolean cacheContains( long N, long n, long R ) {
178 return cache.containsKey( new CacheKey( N, n, R ) );
179 }
180
181
182
183
184
185
186
187
188
189
190 private static BigInteger computeA__( int N0, int n0, int R0 ) {
191 cache.clear();
192 if ( R0 < N0 ) N0 = R0;
193
194
195
196
197
198 if ( N0 == 0 && n0 == 0 ) return BigInteger.ONE;
199
200 for ( int N = 1; N <= N0; N++ ) {
201 if ( N > 2 ) removeFromCache( N - 2 );
202
203
204 long min_n = Math.max( 0, n0 + N - N0 );
205 long max_n = Math.min( n0, N );
206
207 assert min_n >= 0;
208 assert max_n >= min_n;
209
210 for ( long n = min_n; n <= max_n; n++ ) {
211
212
213 long bestPossibleRankSum = n * ( n + 1 ) / 2;
214 long worstPossibleRankSum = n * ( 2 * N - n + 1 ) / 2;
215
216
217 long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1 ) * ( N0 - N ) / 2 );
218 long max_r = Math.min( worstPossibleRankSum, R0 );
219
220 assert min_r >= 0;
221
222 if ( min_r > max_r ) {
223 throw new IllegalStateException( min_r + " > " + max_r );
224 }
225
226 assert max_r >= min_r : String.format( "max_r %d < min_r %d for N=%d, n=%d, r=%d", max_r, min_r, N0,
227 n0, R0 );
228
229
230 long foo = n * ( 2 * N - n - 1 ) / 2;
231
232
233 long bar = N + ( n - 1 ) * n / 2;
234
235 for ( long r = min_r; r <= max_r; r++ ) {
236
237 if ( n == 0 || n == N || r == bestPossibleRankSum ) {
238 addToCache( N, n, r, BigInteger.ONE );
239
240 } else if ( r > foo ) {
241 addToCache( N, n, r, getFromCache( N - 1, n, foo ).add( getFromCache( N - 1, n - 1, r - N ) ) );
242
243 } else if ( r < bar ) {
244 addToCache( N, n, r, getFromCache( N - 1, n, r ) );
245
246 } else {
247 addToCache( N, n, r, getFromCache( N - 1, n, r ).add( getFromCache( N - 1, n - 1, r - N ) ) );
248 }
249 }
250 }
251 }
252 return getFromCache( N0, n0, R0 );
253 }
254
255
256
257
258
259
260
261 private static BigInteger getFromCache( long N, long n, long R ) {
262
263 if ( !cacheContains( N, n, R ) ) {
264 throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
265 }
266 return cache.get( new CacheKey( N, n, R ) );
267 }
268
269
270
271
272
273
274
275 private static double pExact( int N, int n, int R ) {
276 return computeA__( N, n, R ).doubleValue() / Arithmetic.binomial( N, n );
277 }
278
279
280
281
282
283
284
285 private static double pGaussian( long N, long n, long R ) {
286 if ( n > N ) throw new IllegalArgumentException( "n must be smaller than N" );
287 double mean = n * ( N + 1 ) / 2.0;
288 double var = n * ( N - n ) * ( N + 1 ) / 12.0;
289 if ( log.isDebugEnabled() ) log.debug( "Mean=" + mean + " Var=" + var + " R=" + R );
290 return Probability.normal( 0.0, var, R - mean );
291 }
292
293
294
295
296
297
298
299
300
301 private static double pVolume( int N, int n, long R ) {
302
303 double t = R / ( double ) N;
304
305 if ( t < 0 ) return 0.0;
306 if ( t >= n ) return 1.0;
307 double[] logFactors = new double[n + 1];
308 logFactors[0] = 0.0;
309 logFactors[1] = 0.0;
310 for ( int i = 2; i <= n; i++ ) {
311 logFactors[i] = logFactors[i - 1] + Math.log( i );
312 }
313
314 int kMax = ( int ) t;
315 double[][] C = new double[n][];
316 for ( int i = 0; i < n; i++ ) {
317 C[i] = new double[n + 1];
318 C[0][i] = 0.0;
319 }
320
321 C[0][n] = Math.exp( -logFactors[n] );
322 for ( int k = 1; k <= kMax; k++ ) {
323 for ( int a = 0; a < n; a++ ) {
324 for ( int j = a; j <= n; j++ ) {
325 C[k][a] += C[k - 1][j] * Math.exp( logFactors[j] - logFactors[a] - logFactors[j - a] );
326 }
327 }
328 double b = Math.exp( -logFactors[k] - logFactors[n - 1 - k] ) / n;
329 C[k][n] = k % 2 != 0 ? -b : b;
330 }
331
332 double result = 0.0;
333 for ( int a = 0; a <= n; a++ ) {
334 result += C[kMax][a] * Math.pow( t - kMax, a );
335 }
336 return result;
337
338 }
339
340
341
342
343 private static void removeFromCache( int N ) {
344 cache.remove( N );
345 }
346
347 }
348
/**
 * Immutable key identifying one (N, n, R) combination, for use in hash-based maps.
 */
class CacheKey {
    private final long n;
    private final long N;
    private final long R;

    /**
     * @param N first component of the key
     * @param n second component of the key
     * @param R third component of the key
     */
    public CacheKey( long N, long n, long R ) {
        super();
        this.N = N;
        this.n = n;
        this.R = R;
    }

    /**
     * Two keys are equal when all three components match. Returns false (rather than throwing) for
     * <code>null</code> and for objects of other types, as the {@link Object#equals(Object)} contract requires.
     */
    @Override
    public boolean equals( Object obj ) {
        if ( this == obj ) {
            return true;
        }
        if ( !( obj instanceof CacheKey ) ) {
            return false;
        }
        CacheKey other = ( CacheKey ) obj;
        return N == other.N && n == other.n && R == other.R;
    }

    /**
     * Combines the three components with the usual multiply-by-31 scheme; the arithmetic is done in
     * <code>long</code> and narrowed to <code>int</code> at the end.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        long result = 1;
        result = prime * result + N;
        result = prime * result + n;
        result = prime * result + R;
        return ( int ) result;
    }

}