1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package ubic.basecode.math;
20
21 import cern.colt.list.BooleanArrayList;
22 import org.apache.commons.math3.special.Gamma;
23
24 import cern.colt.function.DoubleFunction;
25 import cern.colt.matrix.DoubleMatrix1D;
26 import cern.colt.matrix.impl.DenseDoubleMatrix1D;
27 import cern.jet.math.Functions;
28 import ubic.basecode.dataStructure.matrix.MatrixUtil;
29
30
31
32 import static java.lang.Math.*;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 public class SpecFunc {
48
49
50 private static final double SMALL = 1e-8;
51
52
53
54
55
56
57
58
59
60
61 public static double dbinom(double x, double n, double p) {
62
63 if (p < 0 || p > 1 || n < 0) throw new IllegalArgumentException();
64
65 return dbinom_raw(x, n, p, 1 - p);
66 }
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90 public static double dhyper(int x, int r, int b, int n) {
91 double p, q, p1, p2, p3;
92
93 if (r < 0 || b < 0 || n < 0 || n > r + b) throw new IllegalArgumentException();
94
95 if (x < 0) return 0.0;
96
97 if (n < x || r < x || n - x > b) return 0;
98 if (n == 0) return ((x == 0) ? 1 : 0);
99
100 p = ((double) n) / ((double) (r + b));
101 q = ((double) (r + b - n)) / ((double) (r + b));
102
103 p1 = dbinom_raw(x, r, p, q);
104 p2 = dbinom_raw(n - x, b, p, q);
105 p3 = dbinom_raw(n, r + b, p, q);
106
107 return p1 * p2 / p3;
108 }
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123 public static double phyper(int x, int NR, int NB, int n, boolean lowerTail) {
124
125 double d, pd;
126
127 if (NR < 0 || NB < 0 || n < 0 || n > NR + NB) {
128 throw new IllegalArgumentException("Need NR>=0, NB>=0,n>=0,n>NR+NB");
129 }
130
131 if (x * (NR + NB) > n * NR) {
132
133 int oldNB = NB;
134 NB = NR;
135 NR = oldNB;
136 x = n - x - 1;
137 lowerTail = !lowerTail;
138 }
139
140 if (x < 0) return 0.0;
141
142 d = dhyper(x, NR, NB, n);
143 pd = pdhyper(x, NR, NB, n);
144
145 return lowerTail ? d * pd : 1.0 - (d * pd);
146 }
147
148
149
150
151
152 public static double trigammaInverse(double x) {
153 return trigammaInverse(new DenseDoubleMatrix1D(new double[]{x})).get(0);
154 }
155
156
157
158
159
160
161
162 public static DoubleMatrix1D trigammaInverse(DoubleMatrix1D x) {
163
164 if (x == null || x.size() == 0)
165 return null;
166 DoubleMatrix1D y = x.copy();
167
168
169 BooleanArrayList ok = MatrixUtil.matchingCriteria(y, a -> !Double.isNaN(a) && a <= 1e7 && a >= 1.0e-6);
170 y = MatrixUtil.applyToIndicesMatchingCriteria(y, a -> a < 0, a -> Double.NaN);
171 y = MatrixUtil.applyToIndicesMatchingCriteria(y, a -> a > 1e7, a -> 1 / sqrt(a));
172 y = MatrixUtil.applyToIndicesMatchingCriteria(y, a -> a < 1.0e-6, a -> 1 / a);
173
174
175 DoubleMatrix1D yok = MatrixUtil.stripNonOK(y, ok);
176
177
178 double LB = 0.5;
179 DoubleMatrix1D yokop = yok.copy().assign(Functions.inv).assign(Functions.plus(LB));
180
181 double iter = 0;
182 do {
183
184 DoubleMatrix1D tri = yokop.copy().assign(
185 new DoubleFunction() {
186 @Override
187 public final double apply(double a) {
188 return Gamma.trigamma(a);
189 }
190 });
191
192 DoubleMatrix1D tri2 = tri.copy();
193 DoubleMatrix1D dif = tri.assign(yok, Functions.div).assign(Functions.neg)
194 .assign(Functions.plus(1.0)).assign(tri2, Functions.mult)
195 .assign(new DenseDoubleMatrix1D(PolyGamma.psigamma(yokop.toArray(), 2)), Functions.div);
196
197
198 yokop.assign(dif, Functions.plus);
199
200
201 double max = yokop.copy().assign(Functions.inv).assign(dif, Functions.mult).assign(Functions.neg)
202 .aggregate(Functions.max, Functions.identity);
203 if (max < SMALL) break;
204 } while (++iter < 50);
205
206 MatrixUtil.replaceValues(y, ok, yokop);
207
208 return y;
209
210 }
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232 private static double bd0(double x, double np) {
233 double ej, s, s1, v;
234 int j;
235
236 if (Math.abs(x - np) < 0.1 * (x + np)) {
237 v = (x - np) / (x + np);
238 s = (x - np) * v;
239 ej = 2 * x * v;
240 v = v * v;
241 for (j = 1; ; j++) {
242 ej *= v;
243 s1 = s + ej / ((j << 1) + 1);
244 if (s1 == s)
245 return (s1);
246 s = s1;
247 }
248 }
249
250 return (x * Math.log(x / np) + np - x);
251 }
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276 private static double dbinom_raw(double x, double n, double p, double q) {
277 double f, lc;
278
279 if (p == 0) return ((x == 0) ? 1 : 0);
280 if (q == 0) return ((x == n) ? 1 : 0);
281
282 if (x == 0) {
283 if (n == 0) return 1;
284 lc = (p < 0.1) ? -bd0(n, n * q) - n * p : n * Math.log(q);
285 return (Math.exp(lc));
286 }
287 if (x == n) {
288 lc = (q < 0.1) ? -bd0(n, n * p) - n * q : n * Math.log(p);
289 return (Math.exp(lc));
290 }
291 if (x < 0 || x > n) return (0);
292
293 lc = stirlerr(n) - stirlerr(x) - stirlerr(n - x) - bd0(x, n * p) - bd0(n - x, n * q);
294 f = (2 * Math.PI * x * (n - x)) / n;
295
296 return Math.exp(lc) / Math.sqrt(f);
297 }
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323 private static double pdhyper(int x, int NR, int NB, int n) {
324 double sum = 0.0;
325 double term = 1.0;
326
327 while (x > 0.0 && term >= Double.MIN_VALUE * sum) {
328 term *= (double) x * (NB - n + x) / (n + 1 - x) / (NR + 1 - x);
329 sum += term;
330 x--;
331 }
332
333 return 1.0 + sum;
334 }
335
336
337
338
339
340
341
342
343
344
345
346
347
348 private static double stirlerr(double n) {
349
350 double S0 = 0.083333333333333333333;
351 double S1 = 0.00277777777777777777778;
352 double S2 = 0.00079365079365079365079365;
353 double S3 = 0.000595238095238095238095238;
354 double S4 = 0.0008417508417508417508417508;
355
356
357
358
359 double[] sferr_halves = new double[]{0.0,
360 0.1534264097200273452913848,
361 0.0810614667953272582196702,
362 0.0548141210519176538961390,
363 0.0413406959554092940938221,
364 0.03316287351993628748511048,
365 0.02767792568499833914878929,
366 0.02374616365629749597132920,
367 0.02079067210376509311152277,
368 0.01848845053267318523077934,
369 0.01664469118982119216319487,
370 0.01513497322191737887351255,
371 0.01387612882307074799874573,
372 0.01281046524292022692424986,
373 0.01189670994589177009505572,
374 0.01110455975820691732662991,
375 0.010411265261972096497478567,
376 0.009799416126158803298389475,
377 0.009255462182712732917728637,
378 0.008768700134139385462952823,
379 0.008330563433362871256469318,
380 0.007934114564314020547248100,
381 0.007573675487951840794972024,
382 0.007244554301320383179543912,
383 0.006942840107209529865664152,
384 0.006665247032707682442354394,
385 0.006408994188004207068439631,
386 0.006171712263039457647532867,
387 0.005951370112758847735624416,
388 0.005746216513010115682023589,
389 0.005554733551962801371038690
390
391 };
392
393 double nn;
394
395 if (n <= 15.0) {
396 nn = n + n;
397 if (nn == (int) nn) return (sferr_halves[(int) nn]);
398 return (Gamma.logGamma(n + 1.) - (n + 0.5) * Math.log(n) + n - Constants.M_LN_SQRT_2PI);
399 }
400
401 nn = n * n;
402 if (n > 500) return ((S0 - S1 / nn) / n);
403 if (n > 80) return ((S0 - (S1 - S2 / nn) / nn) / n);
404 if (n > 35) return ((S0 - (S1 - (S2 - S3 / nn) / nn) / nn) / n);
405
406 return ((S0 - (S1 - (S2 - (S3 - S4 / nn) / nn) / nn) / nn) / n);
407 }
408
409 }
410
411
/**
 * Digamma, trigamma and higher polygamma functions built on a single worker, {@link #dpsifn}.
 * <p>
 * NOTE(review): the worker's name and structure suggest a port of R's {@code dpsifn} (Amos' algorithm) -- confirm
 * against upstream before altering numerical details.
 */
class PolyGamma {
    private static final double LOG10_OF_2 = Math.log10(2);
    /** Relative tolerance used by the series and asymptotic expansions. */
    private static final double WD_TOL = Math.max(Math.pow(2, -53), 0.5e-18);
    /** Largest recurrence shift supported; also sizes the work array {@code trmr}. */
    private static final int MAX_VALUE = 100;
    private static final int DBL_MANT_DIG = 53;
    private static final int DBL_MIN_EXP = -1021;
    private static final String ERROR_DOMAIN = "Math Error: DOMAIN";

    /** Coefficients of the asymptotic expansion; {@link #dpsifn} reads entries 2..21. */
    private static final double[] BVALUES = {
            1.00000000000000000e+00,
            -5.00000000000000000e-01,
            1.66666666666666667e-01,
            -3.33333333333333333e-02,
            2.38095238095238095e-02,
            -3.33333333333333333e-02,
            7.57575757575757576e-02,
            -2.53113553113553114e-01,
            1.16666666666666667e+00,
            -7.09215686274509804e+00,
            5.49711779448621554e+01,
            -5.29124242424242424e+02,
            6.19212318840579710e+03,
            -8.65802531135531136e+04,
            1.42551716666666667e+06,
            -2.72982310678160920e+07,
            6.01580873900642368e+08,
            -1.51163157670921569e+10,
            4.29614643061166667e+11,
            -1.37116552050883328e+13,
            4.88332318973593167e+14,
            -1.92965793419400681e+16
    };

    /**
     * Computes {@code m} scaled polygamma values of consecutive orders {@code n, n+1, ..., n+m-1} at {@code x}.
     * <p>
     * Entry {@code j} of the result holds {@code (-1)^(k+1) / k! * psi^(k)(x)} with {@code k = n + j}; this is the
     * scaling the wrappers in this class undo (e.g. {@code digamma = -ans[0]}, {@code pentagamma = 6 * ans[0]}).
     * For {@code x <= 0} only {@code m == 1} and {@code n <= 3} are supported (via the reflection formula).
     *
     * @param x    evaluation point
     * @param n    lowest derivative order, must be &gt;= 0
     * @param kode 1 or 2; with 2 and {@code n == 0} the logarithmic term is treated differently (see code)
     * @param m    number of consecutive orders to compute, must be &gt;= 1
     * @return an array of length {@code max(n + 1, m)} whose first {@code m} entries are the scaled values, all
     *         NaN if {@code x} is NaN, or {@code null} if the arguments are invalid, unsupported, or the computation
     *         would overflow
     */
    public static final double[] dpsifn(double x, int n, int kode, int m) {
        // Validate before allocating: a negative n would otherwise produce a negative array size.
        if (n < 0 || kode < 1 || kode > 2 || m < 1) {
            return null;
        }

        // At least n + 1 long (the historical length), but never shorter than the m entries that are written.
        double[] ans = new double[Math.max(n + 1, m)];
        int i, j, k, mm, mx, nn, np, nx, fn;
        double arg, den, elim, eps, fln, fx, rln, rxsq;
        double s, slope, t, ta, tk, tol, tols, tss, tst;
        double tt, t1, t2, xdmln, xdmy = 0, xinc = 0, xln = 0, xm, xmin;
        double yint;
        double[] trm = new double[23];
        double[] trmr = new double[MAX_VALUE + 1];
        boolean useAsymptotic = false;

        if (Double.isNaN(x)) {
            // NaN in, NaN out; the loops below would otherwise report an underflow (all zeros).
            for (j = 0; j < ans.length; j++) {
                ans[j] = Double.NaN;
            }
            return ans;
        }

        if (x <= 0.) {
            // Non-positive integers are poles.
            if (x == Math.rint(x)) {
                for (j = 0; j < m; j++) {
                    ans[j] = ((j + n) % 2 == 1) ? Double.POSITIVE_INFINITY : Double.NaN;
                }
                return ans;
            }

            // Only a single order up to 3 is implemented for the reflection formula.
            if (m > 1 || n > 3) {
                return null;
            }

            // Reflection: start from the value at 1 - x (which is positive) and add the cotangent term.
            double[] reflected = dpsifn(1. - x, n, 1, m);
            if (reflected == null) {
                return null;
            }
            ans[0] = reflected[0];

            x *= Math.PI;
            // tt is the n-th derivative of cot evaluated at pi * x
            if (n == 0) {
                tt = Math.cos(x) / Math.sin(x);
            } else if (n == 1) {
                tt = -1 / Math.pow(Math.sin(x), 2);
            } else if (n == 2) {
                tt = 2 * Math.cos(x) / Math.pow(Math.sin(x), 3);
            } else if (n == 3) {
                tt = -2 * (2 * Math.pow(Math.cos(x), 2) + 1) / Math.pow(Math.sin(x), 4);
            } else {
                tt = Double.NaN;
            }

            // Inside the loop: t1 == pi^(k+1), t2 == k!, s == (-1)^k
            t1 = t2 = s = 1.;
            for (k = 0, j = k - n; j < m; k++, j++, s = -s) {
                t1 *= Math.PI;
                if (k >= 2) {
                    t2 *= k;
                }
                if (j >= 0) {
                    ans[j] = s * (ans[j] + t1 / t2 * tt);
                }
            }
            if (n == 0 && kode == 2) {
                ans[0] += xln;
            }
            return ans;
        }

        mm = m;
        nx = -DBL_MIN_EXP;

        // elim: approximate limit on the magnitude of exponents before exp() over/underflows
        elim = 2.302 * (nx * LOG10_OF_2 - 3.0);
        xln = Math.log(x);
        xdmln = xln;
        for (; ; ) {
            nn = n + mm - 1;
            fn = nn;
            t = (fn + 1) * xln;

            // Overflow / underflow test for small and large x
            if (Math.abs(t) > elim) {
                if (t <= 0.0) {
                    return null;
                }
            } else {
                if (x < WD_TOL) {
                    // x so small that the leading term x^(-n-1) dominates
                    ans[0] = Math.pow(x, -n - 1.0);
                    if (mm != 1) {
                        for (k = 1; k < mm; k++) {
                            ans[k] = ans[k - 1] / x;
                        }
                    }
                    if (n == 0 && kode == 2) {
                        ans[0] += xln;
                    }
                    return ans;
                }

                // Compute xmin and the number of terms of the series, fln + 1
                rln = LOG10_OF_2 * DBL_MANT_DIG;
                rln = Math.min(rln, 18.06);

                fln = Math.max(rln, 3.0) - 3.0;
                yint = 3.50 + 0.40 * fln;
                slope = 0.21 + fln * (0.0006038 * fln + 0.008677);
                xm = yint + slope * fn;
                mx = (int) xm + 1;
                xmin = mx;
                if (n != 0) {
                    xm = -2.302 * rln - Math.min(0.0, xln);
                    arg = xm / n;
                    arg = Math.min(0.0, arg);
                    eps = Math.exp(arg);
                    xm = 1.0 - eps;
                    if (Math.abs(arg) < 1.0e-3) {
                        xm = -arg;
                    }
                    fln = x * xm / eps;
                    xm = xmin - x;
                    if (xm > 7.0 && fln < 15.0) {
                        // leave the loop with useAsymptotic == false: use the direct series
                        break;
                    }
                }
                xdmy = x;
                xdmln = xln;
                xinc = 0.0;
                if (x < xmin) {
                    // shift the argument up to xmin; the result is recurred back down afterwards
                    nx = (int) x;
                    xinc = xmin - nx;
                    xdmy = x + xinc;
                    xdmln = Math.log(xdmy);
                }

                // Check that the asymptotic expansion at the shifted argument stays within range
                t = fn * xdmln;
                t1 = xdmln + xdmln;
                t2 = t + xdmln;

                tk = Math.max(Math.abs(t), Math.max(Math.abs(t1), Math.abs(t2)));
                if (tk <= elim) {
                    useAsymptotic = true;
                    break;
                }
            }

            // Underflow for the highest remaining order: report zero and retry with one order fewer
            mm--;
            ans[mm] = 0.0;
            if (mm == 0) {
                return ans;
            }
        }

        if (!useAsymptotic) {
            // Direct series
            nn = (int) fln + 1;
            np = n + 1;
            t1 = (n + 1) * xln;
            t = Math.exp(-t1);
            s = t;
            den = x;
            for (i = 1; i <= nn; i++) {
                den = den + 1.0;
                trm[i] = Math.pow(den, -np);
                s += trm[i];
            }
            ans[0] = s;
            if (n == 0 && kode == 2) {
                ans[0] = s + xln;
            }

            if (mm != 1) {
                // Higher orders reuse the stored terms
                tol = WD_TOL / 5.0;
                for (j = 1; j < mm; j++) {
                    t = t / x;
                    s = t;
                    tols = t * tol;
                    den = x;
                    for (i = 1; i <= nn; i++) {
                        den += 1.0;
                        trm[i] /= den;
                        s += trm[i];
                        if (trm[i] < tols) {
                            break;
                        }
                    }
                    ans[j] = s;
                }
            }
            return ans;
        }

        // Asymptotic expansion at the (possibly shifted) argument xdmy, highest order first
        tss = Math.exp(-t);
        tt = 0.5 / xdmy;
        t1 = tt;
        tst = WD_TOL * tt;
        if (nn != 0) {
            t1 = tt + 1.0 / fn;
        }
        rxsq = 1.0 / (xdmy * xdmy);
        ta = 0.5 * rxsq;
        t = (fn + 1) * ta;
        s = t * BVALUES[2];

        if (Math.abs(s) >= tst) {
            tk = 2.0;
            for (k = 4; k <= 22; k++) {
                t = t * ((tk + fn + 1) / (tk + 1.0)) * ((tk + fn) / (tk + 2.0)) * rxsq;
                trm[k] = t * BVALUES[k - 1];

                if (Math.abs(trm[k]) < tst) {
                    break;
                }
                s += trm[k];
                tk += 2.0;
            }
        }
        s = (s + t1) * tss;
        if (xinc != 0.0) {
            // Backward recurrence from xdmy to x
            nx = (int) xinc;
            np = nn + 1;
            if (nx > MAX_VALUE) {
                return null;
            }
            if (nn == 0) {
                s = addShiftReciprocals(s, x, nx);
                return finishDigamma(ans, s, x, xdmy, xdmln, kode);
            }
            xm = xinc - 1.0;
            fx = x + xm;

            // Statement order matters here: fx is formed as x + xm on every pass
            for (i = 1; i <= nx; i++) {
                trmr[i] = Math.pow(fx, -np);
                s += trmr[i];
                xm -= 1.0;
                fx = x + xm;
            }
        }
        ans[mm - 1] = s;
        if (fn == 0) {
            return finishDigamma(ans, s, x, xdmy, xdmln, kode);
        }

        // Lower orders, j < n + mm - 1, derived from the stored terms
        for (j = 2; j <= mm; j++) {
            fn--;
            tss *= xdmy;
            t1 = tt;
            if (fn != 0) {
                t1 = tt + 1.0 / fn;
            }
            t = (fn + 1) * ta;
            s = t * BVALUES[2];
            if (Math.abs(s) >= tst) {
                tk = 4 + fn;
                for (k = 4; k <= 22; k++) {
                    trm[k] = trm[k] * (fn + 1) / tk;
                    if (Math.abs(trm[k]) < tst) {
                        break;
                    }
                    s += trm[k];
                    tk += 2.0;
                }
            }
            s = (s + t1) * tss;

            if (xinc != 0.0) {
                if (fn == 0) {
                    // Order 0 is finished here and must return: continuing into the recurrence below would add
                    // the shift terms a second time and overwrite ans[0].
                    s = addShiftReciprocals(s, x, nx);
                    return finishDigamma(ans, s, x, xdmy, xdmln, kode);
                }
                xm = xinc - 1.0;
                fx = x + xm;
                for (i = 1; i <= nx; i++) {
                    trmr[i] = trmr[i] * fx;
                    s += trmr[i];
                    xm -= 1.0;
                    fx = x + xm;
                }
            }
            ans[mm - j] = s;
            if (fn == 0) {
                return finishDigamma(ans, s, x, xdmy, xdmln, kode);
            }
        }
        return ans;
    }

    /** Adds {@code 1 / (x + nx - i)} for {@code i = 1..nx} to {@code s}: the order-0 backward recurrence. */
    private static double addShiftReciprocals(double s, double x, int nx) {
        for (int i = 1; i <= nx; i++) {
            s += 1.0 / (x + nx - i);
        }
        return s;
    }

    /**
     * Stores the order-0 result in {@code ans[0]}: {@code s - xdmln} unless {@code kode == 2}, in which case
     * {@code s - log(xdmy / x)} is stored only when the argument was shifted ({@code xdmy != x}).
     */
    private static double[] finishDigamma(double[] ans, double s, double x, double xdmy, double xdmln, int kode) {
        if (kode != 2) {
            ans[0] = s - xdmln;
        } else if (xdmy != x) {
            ans[0] = s - Math.log(xdmy / x);
        }
        return ans;
    }

    /**
     * Polygamma function of order {@code n} at {@code x} (order 0 is digamma, order 1 trigamma).
     *
     * @param x evaluation point
     * @param n derivative order
     * @return the function value, or NaN if {@code x} is NaN or {@link #dpsifn} cannot evaluate the request
     */
    public static final double psigamma(double x, int n) {
        double[] ans = dpsifn(x, n, 1, 1);
        if (ans == null) {
            return Double.NaN;
        }

        // Undo the (-1)^(n+1) / n! scaling applied by dpsifn.
        double result = -ans[0];
        for (int k = 1; k <= n; k++) {
            result *= (-k);
        }
        return result;
    }

    /**
     * @param x evaluation point
     * @return digamma(x)
     * @throws ArithmeticException if {@link #dpsifn} cannot evaluate the request
     */
    public static final double digamma(double x) {
        double[] ans = dpsifn(x, 0, 1, 1);
        if (ans == null) {
            throw new ArithmeticException(ERROR_DOMAIN);
        }
        return -ans[0];
    }

    /**
     * @param x evaluation point
     * @return trigamma(x)
     * @throws ArithmeticException if {@link #dpsifn} cannot evaluate the request
     */
    public static final double trigamma(double x) {
        double[] ans = dpsifn(x, 1, 1, 1);
        if (ans == null) {
            throw new ArithmeticException(ERROR_DOMAIN);
        }
        return ans[0];
    }

    /**
     * @param x evaluation point
     * @return the polygamma function of order 2 at {@code x}
     * @throws ArithmeticException if {@link #dpsifn} cannot evaluate the request
     */
    public static final double tetragamma(double x) {
        double[] ans = dpsifn(x, 2, 1, 1);
        if (ans == null) {
            throw new ArithmeticException(ERROR_DOMAIN);
        }
        return -2.0 * ans[0];
    }

    /**
     * @param x evaluation point
     * @return the polygamma function of order 3 at {@code x}
     * @throws ArithmeticException if {@link #dpsifn} cannot evaluate the request
     */
    public static final double pentagamma(double x) {
        double[] ans = dpsifn(x, 3, 1, 1);
        if (ans == null) {
            throw new ArithmeticException(ERROR_DOMAIN);
        }
        return 6.0 * ans[0];
    }

    /**
     * Element-wise {@link #psigamma(double, int)}.
     *
     * @param x     evaluation points
     * @param deriv derivative order
     * @return a new array with one result per input
     */
    public static final double[] psigamma(double[] x, int deriv) {
        final int n = x.length;
        double[] r = new double[n];
        for (int i = 0; i < n; i++) {
            r[i] = psigamma(x[i], deriv);
        }
        return r;
    }

    /** Element-wise digamma; failures yield NaN rather than an exception. */
    public static final double[] digamma(double[] x) {
        return psigamma(x, 0);
    }

    /** Element-wise trigamma; failures yield NaN rather than an exception. */
    public static final double[] trigamma(double[] x) {
        return psigamma(x, 1);
    }

    /** Element-wise polygamma of order 2; failures yield NaN rather than an exception. */
    public static final double[] tetragamma(double[] x) {
        return psigamma(x, 2);
    }

    /** Element-wise polygamma of order 3; failures yield NaN rather than an exception. */
    public static final double[] pentagamma(double[] x) {
        return psigamma(x, 3);
    }

    /**
     * Sums {@code log(psigamma(a + (1 - j) / 2, deriv))} for {@code j = 1..p}.
     *
     * @param a     base argument
     * @param p     number of terms
     * @param deriv derivative order passed to {@link #psigamma(double, int)}
     * @return the sum of logarithms
     */
    public static final double lmvpsigammafn(double a, int p, int deriv) {
        double sum = 0;
        for (int j = 1; j <= p; j++) {
            sum += Math.log(psigamma(a + (1 - j) / 2.0, deriv));
        }
        return sum;
    }

}