View Javadoc
1   /*
2    * The baseCode project
3    *
4    * Copyright (c) 2011 University of British Columbia
5    *
6    * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
7    * the License. You may obtain a copy of the License at
8    *
9    * http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
12   * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
13   * specific language governing permissions and limitations under the License.
14   */
15  package ubic.basecode.math;
16  
17  import java.util.LinkedList;
18  import java.util.Map;
19  import java.util.Queue;
20  import java.util.TreeMap;
21  
22  import cern.colt.list.DoubleArrayList;
23  import cern.colt.matrix.DoubleMatrix1D;
24  import cern.colt.matrix.DoubleMatrix2D;
25  import cern.colt.matrix.impl.DenseDoubleMatrix2D;
26  import cern.jet.stat.Descriptive;
27  import org.apache.commons.lang3.ArrayUtils;
28  import org.apache.commons.math3.analysis.interpolation.LinearInterpolator;
29  import org.apache.commons.math3.analysis.interpolation.LoessInterpolator;
30  import org.apache.commons.math3.analysis.polynomials.PolynomialSplineFunction;
31  import org.apache.commons.math3.exception.OutOfRangeException;
32  import ubic.basecode.math.linearmodels.MeanVarianceEstimator;
33  
34  /**
35   * Methods for moving averages, loess
36   *
37   * @author paul
38   * @author ptan
39   */
40  public class Smooth {
41  
42      /**
43       * Simple moving average that sums the points "backwards".
44       *
45       * @param m
46       * @param windowSize
47       * @return
48       */
49      public static DoubleMatrix1D movingAverage(DoubleMatrix1D m, int windowSize) {
50  
51          Queue<Double> window = new LinkedList<>();
52  
53          double sum = 0.0;
54  
55          assert windowSize > 0;
56  
57          DoubleMatrix1D result = m.like();
58          for (int i = 0; i < m.size(); i++) {
59  
60              double num = m.get(i);
61              sum += num;
62              window.add(num);
63              if (window.size() > windowSize) {
64  
65                  sum -= window.remove();
66              }
67  
68              if (!window.isEmpty()) {
69                  // if ( window.size() == windowSize ) {
70                  result.set(i, sum / window.size());
71              } else {
72                  result.set(i, Double.NaN);
73              }
74          }
75  
76          return result;
77  
78      }
79  
80      /**
81       * Default loess span (This is the default value used by limma-voom)
82       */
83      static final double BANDWIDTH = 0.5;
84  
85      /**
86       * Default number of loess robustness iterations; 0 is probably fine.
87       */
88      static final int ROBUSTNESS_ITERS = 3;
89  
90      /**
91       * @param xy
92       * @return loessFit with default bandwitdh
93       */
94      public static DoubleMatrix2D loessFit(DoubleMatrix2D xy) {
95          return loessFit(xy, BANDWIDTH);
96      }
97  
98      /**
99       * Computes a loess regression line to fit the data
100      *
101      * @param xy        data to be fit
102      * @param bandwidth the span of the smoother (from 2/n to 1 where n is the number of points in xy)
103      * @return loessFit (same dimensions as xy) or null if there are less than 3 data points
104      */
105     public static DoubleMatrix2D loessFit(DoubleMatrix2D xy, double bandwidth) {
106         assert xy != null;
107 
108         DoubleMatrix1D sx = xy.viewColumn(0);
109         DoubleMatrix1D sy = xy.viewColumn(1);
110         Map<Double, Double> map = new TreeMap<>();// to enforce monotonicity
111         for (int i = 0; i < sx.size(); i++) {
112             if (Double.isNaN(sx.get(i)) || Double.isInfinite(sx.get(i)) || Double.isNaN(sy.get(i))
113                     || Double.isInfinite(sy.get(i))) {
114                 continue;
115             }
116             map.put(sx.get(i), sy.get(i));
117         }
118         DoubleMatrix2D xyChecked = new DenseDoubleMatrix2D(map.size(), 2);
119         xyChecked.viewColumn(0).assign(ArrayUtils.toPrimitive(map.keySet().toArray(new Double[0])));
120         xyChecked.viewColumn(1).assign(ArrayUtils.toPrimitive(map.values().toArray(new Double[0])));
121 
122         // in R:
123         // loess(c(1:5),c(1:5)^2,f=0.5,iter=3)
124         // Note: we start to lose some precision here in comparison with R's loess FIXME why? does it matter?
125         DoubleMatrix2D loessFit = new DenseDoubleMatrix2D(xyChecked.rows(), xyChecked.columns());
126 
127         // fit a loess curve
128         LoessInterpolator loessInterpolator = new LoessInterpolator(bandwidth,
129                 ROBUSTNESS_ITERS);
130 
131         double[] loessY = loessInterpolator.smooth(xyChecked.viewColumn(0).toArray(),
132                 xyChecked.viewColumn(1).toArray());
133 
134         loessFit.viewColumn(0).assign(xyChecked.viewColumn(0));
135         loessFit.viewColumn(1).assign(loessY);
136 
137         return loessFit;
138     }
139 
140 
141 
142     /**
143      * Linearlly interpolate values from a given data set
144      *
145      * Similar implementation of R's stats.approxfun(..., rule = 2) where values outside the interval ['min(x)',
146      * 'max(x)'] get the value at the closest data extreme. Also performs sorting based on xTrain.
147      *
148      * @param x the training set of x values
149      * @param y the training set of y values
150      * @param xInterpolate the set of x values to interpolate
151      * @return yInterpolate the interpolated set of y values
152      */
153     public static double[] interpolate( double[] x, double[] y, double[] xInterpolate ) {
154 
155         assert x != null;
156         assert y != null;
157         assert xInterpolate != null;
158         assert x.length == y.length;
159 
160         double[] yInterpolate = new double[xInterpolate.length];
161         LinearInterpolator linearInterpolator = new LinearInterpolator();
162 
163         // make sure that x is strictly increasing
164         DoubleMatrix2D matrix = new DenseDoubleMatrix2D( x.length, 2 );
165         matrix.viewColumn( 0 ).assign( x );
166         matrix.viewColumn( 1 ).assign( y );
167         matrix = matrix.viewSorted( 0 );
168         double[] sortedX = matrix.viewColumn( 0 ).toArray();
169         double[] sortedY = matrix.viewColumn( 1 ).toArray();
170 
171         // make sure x is within the domain
172         DoubleArrayList xList = new DoubleArrayList( sortedX );
173         double x3ListMin = Descriptive.min( xList );
174         double x3ListMax = Descriptive.max( xList );
175         PolynomialSplineFunction fun = linearInterpolator.interpolate( sortedX, sortedY );
176         for ( int i = 0; i < xInterpolate.length; i++ ) {
177             try {
178                 // approx(...,rule=2)
179                 if ( xInterpolate[i] > x3ListMax ) {
180                     yInterpolate[i] = fun.value( x3ListMax );
181                 } else if ( xInterpolate[i] < x3ListMin ) {
182                     yInterpolate[i] = fun.value( x3ListMin );
183                 } else {
184                     yInterpolate[i] = fun.value( xInterpolate[i] );
185                 }
186             } catch ( OutOfRangeException e ) {
187                 // this shouldn't happen anymore
188                 yInterpolate[i] = Double.NaN;
189             }
190         }
191 
192         return yInterpolate;
193     }
194 
195 
196 
197 }