You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") in the original page and is not preserved in this copy.
Posted to commits@commons.apache.org by er...@apache.org on 2021/05/28 22:49:50 UTC
[commons-math] 02/02: MATH-1172: "SimpleCurveFitter" as parent class for curve fitter implementations.
This is an automated email from the ASF dual-hosted git repository.
erans pushed a commit to branch modularized_master
in repository https://gitbox.apache.org/repos/asf/commons-math.git
commit 1d9670cb12613a2d8c27ad318237e000668f8836
Author: Gilles Sadowski <gi...@gmail.com>
AuthorDate: Sat May 29 00:34:28 2021 +0200
MATH-1172: "SimpleCurveFitter" as parent class for curve fitter implementations.
---
.../math4/legacy/fitting/GaussianCurveFitter.java | 274 ++-------------------
.../math4/legacy/fitting/HarmonicCurveFitter.java | 146 ++---------
.../legacy/fitting/PolynomialCurveFitter.java | 70 +-----
.../math4/legacy/fitting/SimpleCurveFitter.java | 213 +++++++++++++++-
.../legacy/fitting/GaussianCurveFitterTest.java | 20 +-
.../legacy/fitting/HarmonicCurveFitterTest.java | 12 +-
.../legacy/fitting/PolynomialCurveFitterTest.java | 10 +-
7 files changed, 262 insertions(+), 483 deletions(-)
diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java
index 85378c9..69a4802 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitter.java
@@ -16,22 +16,15 @@
*/
package org.apache.commons.math4.legacy.fitting;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Comparator;
import java.util.List;
+import java.util.Collection;
import org.apache.commons.math4.legacy.analysis.function.Gaussian;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.OutOfRangeException;
-import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
-import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
import org.apache.commons.math4.legacy.util.FastMath;
/**
@@ -69,7 +62,7 @@ import org.apache.commons.math4.legacy.util.FastMath;
*
* @since 3.3
*/
-public class GaussianCurveFitter extends AbstractCurveFitter {
+public class GaussianCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
/** {@inheritDoc} */
@@ -98,10 +91,6 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
return v;
}
};
- /** Initial guess. */
- private final double[] initialGuess;
- /** Maximum number of iterations of the optimization algorithm. */
- private final int maxIter;
/**
* Constructor used by the factory methods.
@@ -112,8 +101,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
*/
private GaussianCurveFitter(double[] initialGuess,
int maxIter) {
- this.initialGuess = initialGuess;
- this.maxIter = maxIter;
+ super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter);
}
/**
@@ -132,86 +120,27 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
}
/**
- * Configure the start point (initial guess).
- * @param newStart new start point (initial guess)
- * @return a new instance.
- */
- public GaussianCurveFitter withStartPoint(double[] newStart) {
- return new GaussianCurveFitter(newStart.clone(),
- maxIter);
- }
-
- /**
- * Configure the maximum number of iterations.
- * @param newMaxIter maximum number of iterations
- * @return a new instance.
- */
- public GaussianCurveFitter withMaxIterations(int newMaxIter) {
- return new GaussianCurveFitter(initialGuess,
- newMaxIter);
- }
-
- /** {@inheritDoc} */
- @Override
- protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
-
- // Prepare least-squares problem.
- final int len = observations.size();
- final double[] target = new double[len];
- final double[] weights = new double[len];
-
- int i = 0;
- for (WeightedObservedPoint obs : observations) {
- target[i] = obs.getY();
- weights[i] = obs.getWeight();
- ++i;
- }
-
- final AbstractCurveFitter.TheoreticalValuesFunction model =
- new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
-
- final double[] startPoint = initialGuess != null ?
- initialGuess :
- // Compute estimation.
- new ParameterGuesser(observations).guess();
-
- // Return a new least squares problem set up to fit a Gaussian curve to the
- // observed points.
- return new LeastSquaresBuilder().
- maxEvaluations(Integer.MAX_VALUE).
- maxIterations(maxIter).
- start(startPoint).
- target(target).
- weight(new DiagonalMatrix(weights)).
- model(model.getModelFunction(), model.getModelFunctionJacobian()).
- build();
-
- }
-
- /**
* Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
* of a {@link org.apache.commons.math4.legacy.analysis.function.Gaussian.Parametric}
* based on the specified observed points.
*/
- public static class ParameterGuesser {
- /** Normalization factor. */
- private final double norm;
- /** Mean. */
- private final double mean;
- /** Standard deviation. */
- private final double sigma;
-
+ public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser {
/**
- * Constructs instance with the specified observed points.
+ * {@inheritDoc}
*
- * @param observations Observed points from which to guess the
- * parameters of the Gaussian.
+ * @return the guessed parameters, in the following order:
+ * <ul>
+ * <li>Normalization factor</li>
+ * <li>Mean</li>
+ * <li>Standard deviation</li>
+ * </ul>
* @throws NullArgumentException if {@code observations} is
* {@code null}.
* @throws NumberIsTooSmallException if there are less than 3
* observations.
*/
- public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
+ @Override
+ public double[] guess(Collection<WeightedObservedPoint> observations) {
if (observations == null) {
throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
}
@@ -220,68 +149,7 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
}
final List<WeightedObservedPoint> sorted = sortObservations(observations);
- final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
-
- norm = params[0];
- mean = params[1];
- sigma = params[2];
- }
-
- /**
- * Gets an estimation of the parameters.
- *
- * @return the guessed parameters, in the following order:
- * <ul>
- * <li>Normalization factor</li>
- * <li>Mean</li>
- * <li>Standard deviation</li>
- * </ul>
- */
- public double[] guess() {
- return new double[] { norm, mean, sigma };
- }
-
- /**
- * Sort the observations.
- *
- * @param unsorted Input observations.
- * @return the input observations, sorted.
- */
- private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
- final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
-
- final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
- /** {@inheritDoc} */
- @Override
- public int compare(WeightedObservedPoint p1,
- WeightedObservedPoint p2) {
- if (p1 == null && p2 == null) {
- return 0;
- }
- if (p1 == null) {
- return -1;
- }
- if (p2 == null) {
- return 1;
- }
- int comp = Double.compare(p1.getX(), p2.getX());
- if (comp != 0) {
- return comp;
- }
- comp = Double.compare(p1.getY(), p2.getY());
- if (comp != 0) {
- return comp;
- }
- comp = Double.compare(p1.getWeight(), p2.getWeight());
- if (comp != 0) {
- return comp;
- }
- return 0;
- }
- };
-
- Collections.sort(observations, cmp);
- return observations;
+ return basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
}
/**
@@ -309,119 +177,5 @@ public class GaussianCurveFitter extends AbstractCurveFitter {
return new double[] { n, points[maxYIdx].getX(), s };
}
-
- /**
- * Finds index of point in specified points with the largest Y.
- *
- * @param points Points to search.
- * @return the index in specified points array.
- */
- private int findMaxY(WeightedObservedPoint[] points) {
- int maxYIdx = 0;
- for (int i = 1; i < points.length; i++) {
- if (points[i].getY() > points[maxYIdx].getY()) {
- maxYIdx = i;
- }
- }
- return maxYIdx;
- }
-
- /**
- * Interpolates using the specified points to determine X at the
- * specified Y.
- *
- * @param points Points to use for interpolation.
- * @param startIdx Index within points from which to start the search for
- * interpolation bounds points.
- * @param idxStep Index step for searching interpolation bounds points.
- * @param y Y value for which X should be determined.
- * @return the value of X for the specified Y.
- * @throws ZeroException if {@code idxStep} is 0.
- * @throws OutOfRangeException if specified {@code y} is not within the
- * range of the specified {@code points}.
- */
- private double interpolateXAtY(WeightedObservedPoint[] points,
- int startIdx,
- int idxStep,
- double y)
- throws OutOfRangeException {
- if (idxStep == 0) {
- throw new ZeroException();
- }
- final WeightedObservedPoint[] twoPoints
- = getInterpolationPointsForY(points, startIdx, idxStep, y);
- final WeightedObservedPoint p1 = twoPoints[0];
- final WeightedObservedPoint p2 = twoPoints[1];
- if (p1.getY() == y) {
- return p1.getX();
- }
- if (p2.getY() == y) {
- return p2.getX();
- }
- return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
- (p2.getY() - p1.getY()));
- }
-
- /**
- * Gets the two bounding interpolation points from the specified points
- * suitable for determining X at the specified Y.
- *
- * @param points Points to use for interpolation.
- * @param startIdx Index within points from which to start search for
- * interpolation bounds points.
- * @param idxStep Index step for search for interpolation bounds points.
- * @param y Y value for which X should be determined.
- * @return the array containing two points suitable for determining X at
- * the specified Y.
- * @throws ZeroException if {@code idxStep} is 0.
- * @throws OutOfRangeException if specified {@code y} is not within the
- * range of the specified {@code points}.
- */
- private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
- int startIdx,
- int idxStep,
- double y)
- throws OutOfRangeException {
- if (idxStep == 0) {
- throw new ZeroException();
- }
- for (int i = startIdx;
- idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
- i += idxStep) {
- final WeightedObservedPoint p1 = points[i];
- final WeightedObservedPoint p2 = points[i + idxStep];
- if (isBetween(y, p1.getY(), p2.getY())) {
- if (idxStep < 0) {
- return new WeightedObservedPoint[] { p2, p1 };
- } else {
- return new WeightedObservedPoint[] { p1, p2 };
- }
- }
- }
-
- // Boundaries are replaced by dummy values because the raised
- // exception is caught and the message never displayed.
- // TODO: Exceptions should not be used for flow control.
- throw new OutOfRangeException(y,
- Double.NEGATIVE_INFINITY,
- Double.POSITIVE_INFINITY);
- }
-
- /**
- * Determines whether a value is between two other values.
- *
- * @param value Value to test whether it is between {@code boundary1}
- * and {@code boundary2}.
- * @param boundary1 One end of the range.
- * @param boundary2 Other end of the range.
- * @return {@code true} if {@code value} is between {@code boundary1} and
- * {@code boundary2} (inclusive), {@code false} otherwise.
- */
- private boolean isBetween(double value,
- double boundary1,
- double boundary2) {
- return (value >= boundary1 && value <= boundary2) ||
- (value >= boundary2 && value <= boundary1);
- }
}
}
diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java
index b1b0af3..51c6b67 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitter.java
@@ -25,9 +25,6 @@ import org.apache.commons.math4.legacy.exception.MathIllegalStateException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.ZeroException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
-import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
import org.apache.commons.math4.legacy.util.FastMath;
/**
@@ -46,13 +43,9 @@ import org.apache.commons.math4.legacy.util.FastMath;
*
* @since 3.3
*/
-public class HarmonicCurveFitter extends AbstractCurveFitter {
+public class HarmonicCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final HarmonicOscillator.Parametric FUNCTION = new HarmonicOscillator.Parametric();
- /** Initial guess. */
- private final double[] initialGuess;
- /** Maximum number of iterations of the optimization algorithm. */
- private final int maxIter;
/**
* Constructor used by the factory methods.
@@ -63,8 +56,7 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
*/
private HarmonicCurveFitter(double[] initialGuess,
int maxIter) {
- this.initialGuess = initialGuess;
- this.maxIter = maxIter;
+ super(FUNCTION, initialGuess, new ParameterGuesser(), maxIter);
}
/**
@@ -83,63 +75,6 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
}
/**
- * Configure the start point (initial guess).
- * @param newStart new start point (initial guess)
- * @return a new instance.
- */
- public HarmonicCurveFitter withStartPoint(double[] newStart) {
- return new HarmonicCurveFitter(newStart.clone(),
- maxIter);
- }
-
- /**
- * Configure the maximum number of iterations.
- * @param newMaxIter maximum number of iterations
- * @return a new instance.
- */
- public HarmonicCurveFitter withMaxIterations(int newMaxIter) {
- return new HarmonicCurveFitter(initialGuess,
- newMaxIter);
- }
-
- /** {@inheritDoc} */
- @Override
- protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
- // Prepare least-squares problem.
- final int len = observations.size();
- final double[] target = new double[len];
- final double[] weights = new double[len];
-
- int i = 0;
- for (WeightedObservedPoint obs : observations) {
- target[i] = obs.getY();
- weights[i] = obs.getWeight();
- ++i;
- }
-
- final AbstractCurveFitter.TheoreticalValuesFunction model
- = new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION,
- observations);
-
- final double[] startPoint = initialGuess != null ?
- initialGuess :
- // Compute estimation.
- new ParameterGuesser(observations).guess();
-
- // Return a new optimizer set up to fit a Gaussian curve to the
- // observed points.
- return new LeastSquaresBuilder().
- maxEvaluations(Integer.MAX_VALUE).
- maxIterations(maxIter).
- start(startPoint).
- target(target).
- weight(new DiagonalMatrix(weights)).
- model(model.getModelFunction(), model.getModelFunctionJacobian()).
- build();
-
- }
-
- /**
* This class guesses harmonic coefficients from a sample.
* <p>The algorithm used to guess the coefficients is as follows:</p>
*
@@ -238,24 +173,22 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
* estimations, these operations run in \(O(n)\) time, where \(n\) is the
* number of measurements.</p>
*/
- public static class ParameterGuesser {
- /** Amplitude. */
- private final double a;
- /** Angular frequency. */
- private final double omega;
- /** Phase. */
- private final double phi;
-
+ public static class ParameterGuesser extends SimpleCurveFitter.ParameterGuesser {
/**
- * Simple constructor.
+ * {@inheritDoc}
*
- * @param observations Sampled observations.
+ * @return the guessed parameters, in the following order:
+ * <ul>
+ * <li>Amplitude</li>
+ * <li>Angular frequency</li>
+ * <li>Phase</li>
+ * </ul>
* @throws NumberIsTooSmallException if the sample is too short.
* @throws ZeroException if the abscissa range is zero.
* @throws MathIllegalStateException when the guessing procedure cannot
* produce sensible results.
*/
- public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
+ public double[] guess(Collection<WeightedObservedPoint> observations) {
if (observations.size() < 4) {
throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE,
observations.size(), 4, true);
@@ -265,62 +198,15 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
= sortObservations(observations).toArray(new WeightedObservedPoint[0]);
final double aOmega[] = guessAOmega(sorted);
- a = aOmega[0];
- omega = aOmega[1];
+ final double a = aOmega[0];
+ final double omega = aOmega[1];
- phi = guessPhi(sorted);
- }
+ final double phi = guessPhi(sorted, omega);
- /**
- * Gets an estimation of the parameters.
- *
- * @return the guessed parameters, in the following order:
- * <ul>
- * <li>Amplitude</li>
- * <li>Angular frequency</li>
- * <li>Phase</li>
- * </ul>
- */
- public double[] guess() {
return new double[] { a, omega, phi };
}
/**
- * Sort the observations with respect to the abscissa.
- *
- * @param unsorted Input observations.
- * @return the input observations, sorted.
- */
- private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
- final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
-
- // Since the samples are almost always already sorted, this
- // method is implemented as an insertion sort that reorders the
- // elements in place. Insertion sort is very efficient in this case.
- WeightedObservedPoint curr = observations.get(0);
- final int len = observations.size();
- for (int j = 1; j < len; j++) {
- WeightedObservedPoint prec = curr;
- curr = observations.get(j);
- if (curr.getX() < prec.getX()) {
- // the current element should be inserted closer to the beginning
- int i = j - 1;
- WeightedObservedPoint mI = observations.get(i);
- while ((i >= 0) && (curr.getX() < mI.getX())) {
- observations.set(i + 1, mI);
- if (i-- != 0) {
- mI = observations.get(i);
- }
- }
- observations.set(i + 1, curr);
- curr = observations.get(j);
- }
- }
-
- return observations;
- }
-
- /**
* Estimate a first guess of the amplitude and angular frequency.
*
* @param observations Observations, sorted w.r.t. abscissa.
@@ -415,9 +301,11 @@ public class HarmonicCurveFitter extends AbstractCurveFitter {
* Estimate a first guess of the phase.
*
* @param observations Observations, sorted w.r.t. abscissa.
+ * @param omega Angular frequency.
* @return the guessed phase.
*/
- private double guessPhi(WeightedObservedPoint[] observations) {
+ private double guessPhi(WeightedObservedPoint[] observations,
+ double omega) {
// initialize the means
double fcMean = 0;
double fsMean = 0;
diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java
index 325097e..9360b80 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitter.java
@@ -19,10 +19,6 @@ package org.apache.commons.math4.legacy.fitting;
import java.util.Collection;
import org.apache.commons.math4.legacy.analysis.polynomials.PolynomialFunction;
-import org.apache.commons.math4.legacy.exception.MathInternalError;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
-import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
-import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
/**
* Fits points to a {@link
@@ -36,25 +32,19 @@ import org.apache.commons.math4.legacy.linear.DiagonalMatrix;
*
* @since 3.3
*/
-public class PolynomialCurveFitter extends AbstractCurveFitter {
+public class PolynomialCurveFitter extends SimpleCurveFitter {
/** Parametric function to be fitted. */
private static final PolynomialFunction.Parametric FUNCTION = new PolynomialFunction.Parametric();
- /** Initial guess. */
- private final double[] initialGuess;
- /** Maximum number of iterations of the optimization algorithm. */
- private final int maxIter;
/**
* Constructor used by the factory methods.
*
* @param initialGuess Initial guess.
* @param maxIter Maximum number of iterations of the optimization algorithm.
- * @throws MathInternalError if {@code initialGuess} is {@code null}.
*/
private PolynomialCurveFitter(double[] initialGuess,
int maxIter) {
- this.initialGuess = initialGuess;
- this.maxIter = maxIter;
+ super(FUNCTION, initialGuess, null, maxIter);
}
/**
@@ -72,60 +62,4 @@ public class PolynomialCurveFitter extends AbstractCurveFitter {
public static PolynomialCurveFitter create(int degree) {
return new PolynomialCurveFitter(new double[degree + 1], Integer.MAX_VALUE);
}
-
- /**
- * Configure the start point (initial guess).
- * @param newStart new start point (initial guess)
- * @return a new instance.
- */
- public PolynomialCurveFitter withStartPoint(double[] newStart) {
- return new PolynomialCurveFitter(newStart.clone(),
- maxIter);
- }
-
- /**
- * Configure the maximum number of iterations.
- * @param newMaxIter maximum number of iterations
- * @return a new instance.
- */
- public PolynomialCurveFitter withMaxIterations(int newMaxIter) {
- return new PolynomialCurveFitter(initialGuess,
- newMaxIter);
- }
-
- /** {@inheritDoc} */
- @Override
- protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
- // Prepare least-squares problem.
- final int len = observations.size();
- final double[] target = new double[len];
- final double[] weights = new double[len];
-
- int i = 0;
- for (WeightedObservedPoint obs : observations) {
- target[i] = obs.getY();
- weights[i] = obs.getWeight();
- ++i;
- }
-
- final AbstractCurveFitter.TheoreticalValuesFunction model =
- new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
-
- if (initialGuess == null) {
- throw new MathInternalError();
- }
-
- // Return a new least squares problem set up to fit a polynomial curve to the
- // observed points.
- return new LeastSquaresBuilder().
- maxEvaluations(Integer.MAX_VALUE).
- maxIterations(maxIter).
- start(initialGuess).
- target(target).
- weight(new DiagonalMatrix(weights)).
- model(model.getModelFunction(), model.getModelFunctionJacobian()).
- build();
-
- }
-
}
diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java
index 9ad65a4..832168f 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/fitting/SimpleCurveFitter.java
@@ -16,8 +16,14 @@
*/
package org.apache.commons.math4.legacy.fitting;
+import java.util.Collections;
import java.util.Collection;
+import java.util.Comparator;
+import java.util.List;
+import java.util.ArrayList;
+import org.apache.commons.math4.legacy.exception.ZeroException;
+import org.apache.commons.math4.legacy.exception.OutOfRangeException;
import org.apache.commons.math4.legacy.analysis.ParametricUnivariateFunction;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math4.legacy.fitting.leastsquares.LeastSquaresProblem;
@@ -33,6 +39,8 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
private final ParametricUnivariateFunction function;
/** Initial guess for the parameters. */
private final double[] initialGuess;
+ /** Parameter guesser. */
+ private final ParameterGuesser guesser;
/** Maximum number of iterations of the optimization algorithm. */
private final int maxIter;
@@ -42,13 +50,17 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
* @param function Function to fit.
* @param initialGuess Initial guess. Cannot be {@code null}. Its length must
* be consistent with the number of parameters of the {@code function} to fit.
+ * @param guesser Method for providing an initial guess (if {@code initialGuess}
+ * is {@code null}).
* @param maxIter Maximum number of iterations of the optimization algorithm.
*/
- private SimpleCurveFitter(ParametricUnivariateFunction function,
- double[] initialGuess,
- int maxIter) {
+ protected SimpleCurveFitter(ParametricUnivariateFunction function,
+ double[] initialGuess,
+ ParameterGuesser guesser,
+ int maxIter) {
this.function = function;
this.initialGuess = initialGuess;
+ this.guesser = guesser;
this.maxIter = maxIter;
}
@@ -68,7 +80,24 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
*/
public static SimpleCurveFitter create(ParametricUnivariateFunction f,
double[] start) {
- return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
+ return new SimpleCurveFitter(f, start, null, Integer.MAX_VALUE);
+ }
+
+ /**
+ * Creates a curve fitter.
+ * The maximum number of iterations of the optimization algorithm is set
+ * to {@link Integer#MAX_VALUE}.
+ *
+ * @param f Function to fit.
+ * @param guesser Method for providing an initial guess.
+ * @return a curve fitter.
+ *
+ * @see #withStartPoint(double[])
+ * @see #withMaxIterations(int)
+ */
+ public static SimpleCurveFitter create(ParametricUnivariateFunction f,
+ ParameterGuesser guesser) {
+ return new SimpleCurveFitter(f, null, guesser, Integer.MAX_VALUE);
}
/**
@@ -79,6 +108,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
public SimpleCurveFitter withStartPoint(double[] newStart) {
return new SimpleCurveFitter(function,
newStart.clone(),
+ null,
maxIter);
}
@@ -90,6 +120,7 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
public SimpleCurveFitter withMaxIterations(int newMaxIter) {
return new SimpleCurveFitter(function,
initialGuess,
+ guesser,
newMaxIter);
}
@@ -112,14 +143,186 @@ public class SimpleCurveFitter extends AbstractCurveFitter {
= new AbstractCurveFitter.TheoreticalValuesFunction(function,
observations);
+ final double[] startPoint = initialGuess != null ?
+ initialGuess :
+ // Compute estimation.
+ guesser.guess(observations);
+
// Create an optimizer for fitting the curve to the observed points.
return new LeastSquaresBuilder().
maxEvaluations(Integer.MAX_VALUE).
maxIterations(maxIter).
- start(initialGuess).
+ start(startPoint).
target(target).
weight(new DiagonalMatrix(weights)).
model(model.getModelFunction(), model.getModelFunctionJacobian()).
build();
}
+
+ /**
+ * Guesses the parameters.
+ */
+ public static abstract class ParameterGuesser {
+ private final Comparator<WeightedObservedPoint> CMP = new Comparator<WeightedObservedPoint>() {
+ /** {@inheritDoc} */
+ @Override
+ public int compare(WeightedObservedPoint p1,
+ WeightedObservedPoint p2) {
+ if (p1 == null && p2 == null) {
+ return 0;
+ }
+ if (p1 == null) {
+ return -1;
+ }
+ if (p2 == null) {
+ return 1;
+ }
+ int comp = Double.compare(p1.getX(), p2.getX());
+ if (comp != 0) {
+ return comp;
+ }
+ comp = Double.compare(p1.getY(), p2.getY());
+ if (comp != 0) {
+ return comp;
+ }
+ comp = Double.compare(p1.getWeight(), p2.getWeight());
+ if (comp != 0) {
+ return comp;
+ }
+ return 0;
+ }
+ };
+
+ /**
+ * Computes an estimation of the parameters.
+ *
+ * @param obs Observations.
+ * @return the guessed parameters.
+ */
+ public abstract double[] guess(Collection<WeightedObservedPoint> obs);
+
+ /**
+ * Sort the observations.
+ *
+ * @param unsorted Input observations.
+ * @return the input observations, sorted.
+ */
+ protected List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
+ final List<WeightedObservedPoint> observations = new ArrayList<>(unsorted);
+ Collections.sort(observations, CMP);
+ return observations;
+ }
+
+ /**
+ * Finds index of point in specified points with the largest Y.
+ *
+ * @param points Points to search.
+ * @return the index in specified points array.
+ */
+ protected int findMaxY(WeightedObservedPoint[] points) {
+ int maxYIdx = 0;
+ for (int i = 1; i < points.length; i++) {
+ if (points[i].getY() > points[maxYIdx].getY()) {
+ maxYIdx = i;
+ }
+ }
+ return maxYIdx;
+ }
+
+ /**
+ * Interpolates using the specified points to determine X at the
+ * specified Y.
+ *
+ * @param points Points to use for interpolation.
+ * @param startIdx Index within points from which to start the search for
+ * interpolation bounds points.
+ * @param idxStep Index step for searching interpolation bounds points.
+ * @param y Y value for which X should be determined.
+ * @return the value of X for the specified Y.
+ * @throws ZeroException if {@code idxStep} is 0.
+ * @throws OutOfRangeException if specified {@code y} is not within the
+ * range of the specified {@code points}.
+ */
+ protected double interpolateXAtY(WeightedObservedPoint[] points,
+ int startIdx,
+ int idxStep,
+ double y) {
+ if (idxStep == 0) {
+ throw new ZeroException();
+ }
+ final WeightedObservedPoint[] twoPoints
+ = getInterpolationPointsForY(points, startIdx, idxStep, y);
+ final WeightedObservedPoint p1 = twoPoints[0];
+ final WeightedObservedPoint p2 = twoPoints[1];
+ if (p1.getY() == y) {
+ return p1.getX();
+ }
+ if (p2.getY() == y) {
+ return p2.getX();
+ }
+ return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
+ (p2.getY() - p1.getY()));
+ }
+
+ /**
+ * Gets the two bounding interpolation points from the specified points
+ * suitable for determining X at the specified Y.
+ *
+ * @param points Points to use for interpolation.
+ * @param startIdx Index within points from which to start search for
+ * interpolation bounds points.
+ * @param idxStep Index step for search for interpolation bounds points.
+ * @param y Y value for which X should be determined.
+ * @return the array containing two points suitable for determining X at
+ * the specified Y.
+ * @throws ZeroException if {@code idxStep} is 0.
+ * @throws OutOfRangeException if specified {@code y} is not within the
+ * range of the specified {@code points}.
+ */
+ private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
+ int startIdx,
+ int idxStep,
+ double y) {
+ if (idxStep == 0) {
+ throw new ZeroException();
+ }
+ for (int i = startIdx;
+ idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
+ i += idxStep) {
+ final WeightedObservedPoint p1 = points[i];
+ final WeightedObservedPoint p2 = points[i + idxStep];
+ if (isBetween(y, p1.getY(), p2.getY())) {
+ if (idxStep < 0) {
+ return new WeightedObservedPoint[] { p2, p1 };
+ } else {
+ return new WeightedObservedPoint[] { p1, p2 };
+ }
+ }
+ }
+
+ // Boundaries are replaced by dummy values because the raised
+ // exception is caught and the message never displayed.
+ // TODO: Exceptions should not be used for flow control.
+ throw new OutOfRangeException(y,
+ Double.NEGATIVE_INFINITY,
+ Double.POSITIVE_INFINITY);
+ }
+
+ /**
+ * Determines whether a value is between two other values.
+ *
+ * @param value Value to test whether it is between {@code boundary1}
+ * and {@code boundary2}.
+ * @param boundary1 One end of the range.
+ * @param boundary2 Other end of the range.
+ * @return {@code true} if {@code value} is between {@code boundary1} and
+ * {@code boundary2} (inclusive), {@code false} otherwise.
+ */
+ private boolean isBetween(double value,
+ double boundary1,
+ double boundary2) {
+ return (value >= boundary1 && value <= boundary2) ||
+ (value >= boundary2 && value <= boundary1);
+ }
+ }
}
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java
index 7ce5a00..94cbe24 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/GaussianCurveFitterTest.java
@@ -180,7 +180,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit01() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET1).toList());
Assert.assertEquals(3496978.1837704973, parameters[0], 1e-7);
@@ -190,7 +190,7 @@ public class GaussianCurveFitterTest {
@Test
public void testDataset1LargeXShift() {
- final GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ final SimpleCurveFitter fitter = GaussianCurveFitter.create();
final double xShift = 1e8;
final double[] parameters = fitter.fit(createDataset(DATASET1, xShift, 0).toList());
@@ -204,7 +204,7 @@ public class GaussianCurveFitterTest {
final int maxIter = 20;
final double[] init = { 3.5e6, 4.2, 0.1 };
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter
.withMaxIterations(maxIter)
.withStartPoint(init)
@@ -220,7 +220,7 @@ public class GaussianCurveFitterTest {
final int maxIter = 1; // Too few iterations.
final double[] init = { 3.5e6, 4.2, 0.1 };
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
fitter.withMaxIterations(maxIter)
.withStartPoint(init)
.fit(createDataset(DATASET1).toList());
@@ -230,7 +230,7 @@ public class GaussianCurveFitterTest {
public void testWithStartPoint() {
final double[] init = { 3.5e6, 4.2, 0.1 };
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter
.withStartPoint(init)
.fit(createDataset(DATASET1).toList());
@@ -253,7 +253,7 @@ public class GaussianCurveFitterTest {
*/
@Test(expected=MathIllegalArgumentException.class)
public void testFit03() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
fitter.fit(createDataset(new double[][] {
{4.0254623, 531026.0},
{4.02804905, 664002.0}
@@ -265,7 +265,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit04() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET2).toList());
Assert.assertEquals(233003.2967252038, parameters[0], 1e-4);
@@ -278,7 +278,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit05() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET3).toList());
Assert.assertEquals(283863.81929180305, parameters[0], 1e-4);
@@ -291,7 +291,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit06() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET4).toList());
Assert.assertEquals(285250.66754309234, parameters[0], 1e-4);
@@ -304,7 +304,7 @@ public class GaussianCurveFitterTest {
*/
@Test
public void testFit07() {
- GaussianCurveFitter fitter = GaussianCurveFitter.create();
+ SimpleCurveFitter fitter = GaussianCurveFitter.create();
double[] parameters = fitter.fit(createDataset(DATASET5).toList());
Assert.assertEquals(3514384.729342235, parameters[0], 1e-4);
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java
index c044a6a..05c3f82 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/HarmonicCurveFitterTest.java
@@ -49,7 +49,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x));
}
- final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
+ final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 1.0e-13);
Assert.assertEquals(w, fitted[1], 1.0e-13);
@@ -74,7 +74,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian());
}
- final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
+ final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 2.7e-3);
@@ -90,7 +90,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, 1e-7 * randomizer.nextGaussian());
}
- final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
+ final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
fitter.fit(points.toList());
// This test serves to cover the part of the code of "guessAOmega"
@@ -110,7 +110,7 @@ public class HarmonicCurveFitterTest {
points.add(1, x, f.value(x) + 0.01 * randomizer.nextGaussian());
}
- final HarmonicCurveFitter fitter = HarmonicCurveFitter.create()
+ final SimpleCurveFitter fitter = HarmonicCurveFitter.create()
.withStartPoint(new double[] { 0.15, 3.6, 4.5 });
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 1.2e-3);
@@ -153,7 +153,7 @@ public class HarmonicCurveFitterTest {
points.add(1, xTab[i], yTab[i]);
}
- final HarmonicCurveFitter fitter = HarmonicCurveFitter.create();
+ final SimpleCurveFitter fitter = HarmonicCurveFitter.create();
final double[] fitted = fitter.fit(points.toList());
Assert.assertEquals(a, fitted[0], 7.6e-4);
Assert.assertEquals(w, fitted[1], 3.5e-3);
@@ -177,6 +177,6 @@ public class HarmonicCurveFitterTest {
// and period 12, and all sample points are taken at integer abscissae
// so function values all belong to the integer subset {-3, -2, -1, 0,
// 1, 2, 3}.
- new HarmonicCurveFitter.ParameterGuesser(points);
+ new HarmonicCurveFitter.ParameterGuesser().guess(points);
}
}
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java
index 5004eb2..c540e24 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/fitting/PolynomialCurveFitterTest.java
@@ -48,7 +48,7 @@ public class PolynomialCurveFitterTest {
}
// Start fit from initial guesses that are far from the optimal values.
- final PolynomialCurveFitter fitter
+ final SimpleCurveFitter fitter
= PolynomialCurveFitter.create(0).withStartPoint(new double[] { -1e-20, 3e15, -5e25 });
final double[] best = fitter.fit(obs.toList());
@@ -60,7 +60,7 @@ public class PolynomialCurveFitterTest {
final Random randomizer = new Random(64925784252l);
for (int degree = 1; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
- final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
+ final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (int i = 0; i <= degree; ++i) {
@@ -83,7 +83,7 @@ public class PolynomialCurveFitterTest {
double maxError = 0;
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
- final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
+ final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (double x = -1.0; x < 1.0; x += 0.01) {
@@ -114,7 +114,7 @@ public class PolynomialCurveFitterTest {
double maxError = 0;
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
- final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
+ final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
for (int i = 0; i < 40000; ++i) {
@@ -138,7 +138,7 @@ public class PolynomialCurveFitterTest {
for (int degree = 0; degree < 10; ++degree) {
final PolynomialFunction p = buildRandomPolynomial(degree, randomizer);
- final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
+ final SimpleCurveFitter fitter = PolynomialCurveFitter.create(degree);
final WeightedObservedPoints obs = new WeightedObservedPoints();
// reusing the same point over and over again does not bring
// information, the problem cannot be solved in this case for