You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@commons.apache.org by er...@apache.org on 2012/08/18 03:09:25 UTC

svn commit: r1374492 - /commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java

Author: erans
Date: Sat Aug 18 01:09:25 2012
New Revision: 1374492

URL: http://svn.apache.org/viewvc?rev=1374492&view=rev
Log:
Code cleanup: Moved all computations of "GaussianFitter.ParameterGuesser" to its
constructor, making that class immutable.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java?rev=1374492&r1=1374491&r2=1374492&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/fitting/GaussianFitter.java Sat Aug 18 01:09:25 2012
@@ -30,6 +30,7 @@ import org.apache.commons.math3.exceptio
 import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
 import org.apache.commons.math3.optimization.fitting.CurveFitter;
 import org.apache.commons.math3.optimization.fitting.WeightedObservedPoint;
+import org.apache.commons.math3.util.FastMath;
 
 /**
  * Fits points to a {@link
@@ -127,15 +128,22 @@ public class GaussianFitter extends Curv
      * based on the specified observed points.
      */
     public static class ParameterGuesser {
-        /** Observed points. */
-        private final WeightedObservedPoint[] observations;
-        /** Resulting guessed parameters. */
-        private double[] parameters;
+        /** Normalization factor. */
+        private final double norm;
+        /** Mean. */
+        private final double mean;
+        /** Standard deviation. */
+        private final double sigma;
 
         /**
          * Constructs instance with the specified observed points.
          *
-         * @param observations observed points upon which should base guess
+         * @param observations Observed points from which to guess the
+         * parameters of the Gaussian.
+         * @throws NullArgumentException if {@code observations} is
+         * {@code null}.
+         * @throws NumberIsTooSmallException if there are less than 3
+         * observations.
          */
         public ParameterGuesser(WeightedObservedPoint[] observations) {
             if (observations == null) {
@@ -144,47 +152,101 @@ public class GaussianFitter extends Curv
             if (observations.length < 3) {
                 throw new NumberIsTooSmallException(observations.length, 3, true);
             }
-            this.observations = observations.clone();
+
+            final WeightedObservedPoint[] sorted = sortObservations(observations);
+            final double[] params = basicGuess(sorted);
+
+            norm = params[0];
+            mean = params[1];
+            sigma = params[2];
         }
 
         /**
-         * Guesses the parameters based on the observed points.
+         * Gets an estimation of the parameters.
          *
-         * @return the guessed parameters: norm, mean and sigma.
+         * @return the guessed parameters, in the following order:
+         * <ul>
+         *  <li>Normalization factor</li>
+         *  <li>Mean</li>
+         *  <li>Standard deviation</li>
+         * </ul>
          */
         public double[] guess() {
-            if (parameters == null) {
-                parameters = basicGuess(observations);
-            }
-            return parameters.clone();
+            return new double[] { norm, mean, sigma };
+        }
+
+        /**
+         * Sort the observations.
+         *
+         * @param unsorted Input observations.
+         * @return the input observations, sorted.
+         */
+        private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
+            final WeightedObservedPoint[] observations = unsorted.clone();
+            final Comparator<WeightedObservedPoint> cmp
+                = new Comparator<WeightedObservedPoint>() {
+                public int compare(WeightedObservedPoint p1,
+                                   WeightedObservedPoint p2) {
+                    if (p1 == null && p2 == null) {
+                        return 0;
+                    }
+                    if (p1 == null) {
+                        return -1;
+                    }
+                    if (p2 == null) {
+                        return 1;
+                    }
+                    if (p1.getX() < p2.getX()) {
+                        return -1;
+                    }
+                    if (p1.getX() > p2.getX()) {
+                        return 1;
+                    }
+                    if (p1.getY() < p2.getY()) {
+                        return -1;
+                    }
+                    if (p1.getY() > p2.getY()) {
+                        return 1;
+                    }
+                    if (p1.getWeight() < p2.getWeight()) {
+                        return -1;
+                    }
+                    if (p1.getWeight() > p2.getWeight()) {
+                        return 1;
+                    }
+                    return 0;
+                }
+            };
+
+            Arrays.sort(observations, cmp);
+            return observations;
         }
 
         /**
          * Guesses the parameters based on the specified observed points.
          *
-         * @param points Observed points upon which should base guess.
-         * @return the guessed parameters: norm, mean and sigma.
+         * @param points Observed points, sorted.
+         * @return the guessed parameters (normalization factor, mean and
+         * sigma).
          */
         private double[] basicGuess(WeightedObservedPoint[] points) {
-            Arrays.sort(points, createWeightedObservedPointComparator());
-            double[] params = new double[3];
-
-            int maxYIdx = findMaxY(points);
-            params[0] = points[maxYIdx].getY();
-            params[1] = points[maxYIdx].getX();
+            final int maxYIdx = findMaxY(points);
+            final double n = points[maxYIdx].getY();
+            final double m = points[maxYIdx].getX();
 
             double fwhmApprox;
             try {
-                double halfY = params[0] + ((params[1] - params[0]) / 2.0);
-                double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
-                double fwhmX2 = interpolateXAtY(points, maxYIdx, +1, halfY);
+                final double halfY = n + ((m - n) / 2);
+                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
+                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
                 fwhmApprox = fwhmX2 - fwhmX1;
             } catch (OutOfRangeException e) {
+                // TODO: Exceptions should not be used for flow control.
                 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
             }
-            params[2] = fwhmApprox / (2.0 * Math.sqrt(2.0 * Math.log(2.0)));
+            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
 
-            return params;
+            return new double[] { n, m, s };
         }
 
         /**
@@ -208,33 +270,35 @@ public class GaussianFitter extends Curv
          * specified Y.
          *
          * @param points Points to use for interpolation.
-         * @param startIdx Index within points from which to start search for
-         *  interpolation bounds points.
-         * @param idxStep Index step for search for interpolation bounds points.
+         * @param startIdx Index within points from which to start the search for
+         * interpolation bounds points.
+         * @param idxStep Index step for searching interpolation bounds points.
          * @param y Y value for which X should be determined.
-         * @return the value of X at the specified Y.
+         * @return the value of X for the specified Y.
          * @throws ZeroException if {@code idxStep} is 0.
          * @throws OutOfRangeException if specified {@code y} is not within the
          * range of the specified {@code points}.
          */
         private double interpolateXAtY(WeightedObservedPoint[] points,
-                                       int startIdx, int idxStep, double y)
+                                       int startIdx,
+                                       int idxStep,
+                                       double y)
             throws OutOfRangeException {
             if (idxStep == 0) {
                 throw new ZeroException();
             }
-            WeightedObservedPoint[] twoPoints = getInterpolationPointsForY(points, startIdx, idxStep, y);
-            WeightedObservedPoint pointA = twoPoints[0];
-            WeightedObservedPoint pointB = twoPoints[1];
-            if (pointA.getY() == y) {
-                return pointA.getX();
+            final WeightedObservedPoint[] twoPoints
+                = getInterpolationPointsForY(points, startIdx, idxStep, y);
+            final WeightedObservedPoint p1 = twoPoints[0];
+            final WeightedObservedPoint p2 = twoPoints[1];
+            if (p1.getY() == y) {
+                return p1.getX();
             }
-            if (pointB.getY() == y) {
-                return pointB.getX();
+            if (p2.getY() == y) {
+                return p2.getX();
             }
-            return pointA.getX() +
-                   (((y - pointA.getY()) * (pointB.getX() - pointA.getX())) /
-                    (pointB.getY() - pointA.getY()));
+            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
+                                (p2.getY() - p1.getY()));
         }
 
         /**
@@ -253,84 +317,50 @@ public class GaussianFitter extends Curv
          * range of the specified {@code points}.
          */
         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
-                                                                   int startIdx, int idxStep, double y)
+                                                                   int startIdx,
+                                                                   int idxStep,
+                                                                   double y)
             throws OutOfRangeException {
             if (idxStep == 0) {
                 throw new ZeroException();
             }
             for (int i = startIdx;
-                 (idxStep < 0) ? (i + idxStep >= 0) : (i + idxStep < points.length);
+                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
                  i += idxStep) {
-                if (isBetween(y, points[i].getY(), points[i + idxStep].getY())) {
-                    return (idxStep < 0) ?
-                           new WeightedObservedPoint[] { points[i + idxStep], points[i] } :
-                           new WeightedObservedPoint[] { points[i], points[i + idxStep] };
+                final WeightedObservedPoint p1 = points[i];
+                final WeightedObservedPoint p2 = points[i + idxStep];
+                if (isBetween(y, p1.getY(), p2.getY())) {
+                    if (idxStep < 0) {
+                        return new WeightedObservedPoint[] { p2, p1 };
+                    } else {
+                        return new WeightedObservedPoint[] { p1, p2 };
+                    }
                 }
             }
 
-            double minY = Double.POSITIVE_INFINITY;
-            double maxY = Double.NEGATIVE_INFINITY;
-            for (final WeightedObservedPoint point : points) {
-                minY = Math.min(minY, point.getY());
-                maxY = Math.max(maxY, point.getY());
-            }
-            throw new OutOfRangeException(y, minY, maxY);
+            // Boundaries are replaced by dummy values because the raised
+            // exception is caught and the message never displayed.
+            // TODO: Exceptions should not be used for flow control.
+            throw new OutOfRangeException(y,
+                                          Double.NEGATIVE_INFINITY,
+                                          Double.POSITIVE_INFINITY);
         }
 
         /**
          * Determines whether a value is between two other values.
          *
-         * @param value Value to determine whether is between {@code boundary1}
+         * @param value Value to test whether it is between {@code boundary1}
          * and {@code boundary2}.
          * @param boundary1 One end of the range.
          * @param boundary2 Other end of the range.
          * @return {@code true} if {@code value} is between {@code boundary1} and
          * {@code boundary2} (inclusive), {@code false} otherwise.
          */
-        private boolean isBetween(double value, double boundary1, double boundary2) {
+        private boolean isBetween(double value,
+                                  double boundary1,
+                                  double boundary2) {
             return (value >= boundary1 && value <= boundary2) ||
-                   (value >= boundary2 && value <= boundary1);
-        }
-
-        /**
-         * Factory method creating {@code Comparator} for comparing
-         * {@code WeightedObservedPoint} instances.
-         *
-         * @return the new {@code Comparator} instance.
-         */
-        private Comparator<WeightedObservedPoint> createWeightedObservedPointComparator() {
-            return new Comparator<WeightedObservedPoint>() {
-                public int compare(WeightedObservedPoint p1, WeightedObservedPoint p2) {
-                    if (p1 == null && p2 == null) {
-                        return 0;
-                    }
-                    if (p1 == null) {
-                        return -1;
-                    }
-                    if (p2 == null) {
-                        return 1;
-                    }
-                    if (p1.getX() < p2.getX()) {
-                        return -1;
-                    }
-                    if (p1.getX() > p2.getX()) {
-                        return 1;
-                    }
-                    if (p1.getY() < p2.getY()) {
-                        return -1;
-                    }
-                    if (p1.getY() > p2.getY()) {
-                        return 1;
-                    }
-                    if (p1.getWeight() < p2.getWeight()) {
-                        return -1;
-                    }
-                    if (p1.getWeight() > p2.getWeight()) {
-                        return 1;
-                    }
-                    return 0;
-                }
-            };
+                (value >= boundary2 && value <= boundary1);
         }
     }
 }