1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package ubic.basecode.math;
20
import cern.colt.list.DoubleArrayList;
import cern.jet.math.Arithmetic;
import cern.jet.stat.Probability;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigInteger;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
32
33
34
35
36
37
38
39
40
41
42
43
44 public class Wilcoxon {
45
46
47
48
49
50
51 private static final long LIMIT_FOR_APPROXIMATION = 100000L;
52
53 private static final Logger log = LoggerFactory.getLogger( Wilcoxon.class );
54
55
56
57
58
59 public static double exactWilcoxonP( double[] a, double[] b ) {
60 int fullLength = a.length + b.length;
61
62 DoubleArrayList ad = new DoubleArrayList( a );
63 DoubleArrayList bb = new DoubleArrayList( b );
64 ad.addAllOf( bb );
65 DoubleArrayList abR = Rank.rankTransform( ad );
66 int aSum = 0;
67 for ( int i = 0; i < a.length; i++ ) {
68 aSum += abR.get( i );
69 }
70
71 if ( aSum > LIMIT_FOR_APPROXIMATION ) {
72 throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of rank sum will fail." );
73 }
74 return pExact( fullLength, a.length, aSum );
75 }
76
77 public static double exactWilcoxonP( int N, int n, int R ) {
78 if ( R > LIMIT_FOR_APPROXIMATION ) {
79 throw new IllegalArgumentException( "Computation of exact wilcoxon for large values of R will fail." );
80 }
81 return pExact( N, n, R );
82 }
83
84
85
86
87 public static double wilcoxonP( int N, int n, long R ) {
88 return wilcoxonP( N, n, R, false );
89 }
90
91
92
93
94
95
96
97 public static double wilcoxonP( int N, int n, long R, boolean ties ) {
98
99 if ( n > N ) throw new IllegalArgumentException( "n must be less than N (n=" + n + ", N=" + N + ")" );
100
101 if ( n == 0 && N == 0 ) return 1.0;
102
103 if ( ( !ties )
104 && ( ( ( long ) N * ( long ) n <= LIMIT_FOR_APPROXIMATION && n * R <= LIMIT_FOR_APPROXIMATION && ( long ) N
105 * ( long ) n * R <= LIMIT_FOR_APPROXIMATION ) || ( R < N && n * Math.pow( R, 2 ) <= LIMIT_FOR_APPROXIMATION ) ) ) {
106 if ( log.isDebugEnabled() ) log.debug( "Using exact method (" + N * n * R + ")" );
107 return pExact( N, n, ( int ) R );
108 }
109
110 double p = pGaussian( N, n, R );
111
112 if ( p < 0.1 && Math.pow( n, 2 ) * R / N <= 1e5 ) {
113 if ( log.isDebugEnabled() ) log.debug( "Using volume method (" + N * n * R + ")" );
114 return pVolume( N, n, R );
115 }
116
117 if ( log.isDebugEnabled() ) log.debug( "Using gaussian method (" + N * n * R + ")" );
118 return p;
119 }
120
121
122
123
124
125 public static double wilcoxonP( int N, List<Double> ranks ) {
126
127
128
129
130 Collections.sort( ranks );
131 Double p = null;
132 boolean ties = false;
133 for ( Double r : ranks ) {
134 if ( p != null ) {
135 if ( r.equals( p ) ) {
136 ties = true;
137 break;
138 }
139 }
140 p = r;
141 }
142
143 long rankSum = Rank.rankSum( ranks );
144
145 return wilcoxonP( N, ranks.size(), rankSum, ties );
146 }
147
148
149
150
151
152
153
154 private static BigInteger computeA__( int N0, int n0, int R0 ) {
155 Map<CacheKey, BigInteger> cache = new HashMap<>();
156
157 if ( R0 < N0 ) N0 = R0;
158
159 if ( N0 == 0 && n0 == 0 ) return BigInteger.ONE;
160
161 for ( int N = 1; N <= N0; N++ ) {
162
163 long min_n = Math.max( 0, n0 + N - N0 );
164 long max_n = Math.min( n0, N );
165
166 assert max_n >= min_n;
167
168 for ( long n = min_n; n <= max_n; n++ ) {
169
170
171 long bestPossibleRankSum = n * ( n + 1 ) / 2;
172 long worstPossibleRankSum = n * ( 2L * N - n + 1 ) / 2;
173
174
175 long min_r = Math.max( bestPossibleRankSum, R0 - ( N0 + N + 1L ) * ( N0 - N ) / 2 );
176 long max_r = Math.min( worstPossibleRankSum, R0 );
177
178 assert min_r >= 0;
179
180 if ( min_r > max_r ) {
181 throw new IllegalStateException( min_r + " > " + max_r );
182 }
183
184 assert max_r >= min_r : String.format( "max_r %d < min_r %d for N=%d, n=%d, r=%d", max_r, min_r, N0,
185 n0, R0 );
186
187
188 long foo = n * ( 2L * N - n - 1 ) / 2;
189
190
191 long bar = N + ( n - 1 ) * n / 2;
192
193 for ( long r = min_r; r <= max_r; r++ ) {
194
195 if ( n == 0 || n == N || r == bestPossibleRankSum ) {
196 addToCache( cache, N, n, r, BigInteger.ONE );
197
198 } else if ( r > foo ) {
199 addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, foo ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
200
201 } else if ( r < bar ) {
202 addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ) );
203
204 } else {
205 addToCache( cache, N, n, r, getFromCache( cache, N - 1, n, r ).add( getFromCache( cache, N - 1, n - 1, r - N ) ) );
206 }
207 }
208 }
209 }
210
211 return getFromCache( cache, N0, n0, R0 );
212 }
213
214
215
216
217 private static double pExact( int N, int n, int R ) {
218 return computeA__( N, n, R ).doubleValue() / Arithmetic.binomial( N, n );
219 }
220
221
222
223
224 private static double pGaussian( long N, long n, long R ) {
225 if ( n > N ) throw new IllegalArgumentException( "n must be smaller than N" );
226 double mean = n * ( N + 1 ) / 2.0;
227 double var = n * ( N - n ) * ( N + 1 ) / 12.0;
228 if ( log.isDebugEnabled() ) log.debug( "Mean=" + mean + " Var=" + var + " R=" + R );
229 return Probability.normal( 0.0, var, R - mean );
230 }
231
232
233
234
235 private static double pVolume( int N, int n, long R ) {
236
237 double t = R / ( double ) N;
238
239 if ( t < 0 ) return 0.0;
240 if ( t >= n ) return 1.0;
241 double[] logFactors = new double[n + 1];
242 logFactors[0] = 0.0;
243 logFactors[1] = 0.0;
244 for ( int i = 2; i <= n; i++ ) {
245 logFactors[i] = logFactors[i - 1] + Math.log( i );
246 }
247
248 int kMax = ( int ) t;
249 double[][] C = new double[n][];
250 for ( int i = 0; i < n; i++ ) {
251 C[i] = new double[n + 1];
252 C[0][i] = 0.0;
253 }
254
255 C[0][n] = Math.exp( -logFactors[n] );
256 for ( int k = 1; k <= kMax; k++ ) {
257 for ( int a = 0; a < n; a++ ) {
258 for ( int j = a; j <= n; j++ ) {
259 C[k][a] += C[k - 1][j] * Math.exp( logFactors[j] - logFactors[a] - logFactors[j - a] );
260 }
261 }
262 double b = Math.exp( -logFactors[k] - logFactors[n - 1 - k] ) / n;
263 C[k][n] = k % 2 != 0 ? -b : b;
264 }
265
266 double result = 0.0;
267 for ( int a = 0; a <= n; a++ ) {
268 result += C[kMax][a] * Math.pow( t - kMax, a );
269 }
270 return result;
271
272 }
273
274 private static void addToCache( Map<CacheKey, BigInteger> cache, long N, long n, long R, BigInteger value ) {
275 cache.put( new CacheKey( N, n, R ), value );
276 }
277
278 private static boolean cacheContains( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
279 return cache.containsKey( new CacheKey( N, n, R ) );
280 }
281
282 private static BigInteger getFromCache( Map<CacheKey, BigInteger> cache, long N, long n, long R ) {
283
284 if ( !cacheContains( cache, N, n, R ) ) {
285 throw new IllegalStateException( "No value stored for N=" + N + ", n=" + n + ", R=" + R );
286 }
287 return cache.get( new CacheKey( N, n, R ) );
288 }
289
290 private static class CacheKey {
291 private final long n;
292 private final long N;
293 private final long R;
294
295 public CacheKey( long N, long n, long R ) {
296 this.N = N;
297 this.n = n;
298 this.R = R;
299 }
300
301 @Override
302 public boolean equals( Object obj ) {
303 if ( obj instanceof CacheKey ) {
304 CacheKey other = ( CacheKey ) obj;
305 return N == other.N && n == other.n && R == other.R;
306 } else {
307 return false;
308 }
309 }
310
311 @Override
312 public int hashCode() {
313 final int prime = 31;
314 long result = 1;
315 result = prime * result + N;
316 result = prime * result + n;
317 result = prime * result + R;
318 return ( int ) result;
319 }
320 }
321 }