You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2014/11/03 11:43:36 UTC
[2/8] git commit: MATH-1144 Allow caller to modify the set of parameters generated by the optimizer.
MATH-1144
Allow caller to modify the set of parameters generated by the optimizer.
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/321fd029
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/321fd029
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/321fd029
Branch: refs/heads/master
Commit: 321fd029ec5c9c3c9717f1ede0add49d8a709a01
Parents: f820d06
Author: Gilles <er...@apache.org>
Authored: Mon Oct 13 19:43:26 2014 +0200
Committer: Gilles <er...@apache.org>
Committed: Mon Oct 13 19:43:26 2014 +0200
----------------------------------------------------------------------
.../leastsquares/LeastSquaresBuilder.java | 46 ++++++++++++++++-
.../leastsquares/LeastSquaresFactory.java | 54 ++++++++++++++------
.../leastsquares/ValueAndJacobianFunction.java | 2 +-
.../fitting/leastsquares/EvaluationTest.java | 8 +--
.../LevenbergMarquardtOptimizerTest.java | 42 +++++++++++++++
5 files changed, 129 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
index 7d3ccbb..7b14b37 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
@@ -47,6 +47,17 @@ public class LeastSquaresBuilder {
private RealVector start;
/** weight matrix */
private RealMatrix weight;
+ /**
+ * Whether {@link LeastSquaresProblem#evaluate(RealVector)} defers computing
+ * the model values until they are accessed ({@code false} unless configured).
+ *
+ * @since 3.4
+ */
+ private boolean lazyEvaluation;
+ /** Validator of the model parameters; {@code null} (the default) means that
+ * no validation is performed.
+ *
+ * @since 3.4
+ */
+ private ParameterValidator paramValidator;
/**
@@ -55,7 +66,15 @@ public class LeastSquaresBuilder {
* @return a new {@link LeastSquaresProblem}.
*/
public LeastSquaresProblem build() {
- return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
+ return LeastSquaresFactory.create(model,
+ target,
+ start,
+ weight, // may be null: the problem is then left unweighted
+ checker,
+ maxEvaluations,
+ maxIterations,
+ lazyEvaluation, // false unless lazyEvaluation(true) was called
+ paramValidator); // null (no validation) unless parameterValidator(...) was called
}
/**
@@ -179,4 +198,29 @@ public class LeastSquaresBuilder {
return this;
}
+ /**
+ * Configure whether evaluation will be lazy or not.
+ * Lazy evaluation defers computing the model values until they are accessed.
+ * It requires the model to implement {@link ValueAndJacobianFunction};
+ * otherwise {@link #build()} throws a {@code MathIllegalStateException}.
+ * Evaluation is eager unless this is set to {@code true}.
+ *
+ * @param newValue Whether to perform lazy evaluation.
+ * @return this object.
+ *
+ * @since 3.4
+ */
+ public LeastSquaresBuilder lazyEvaluation(final boolean newValue) {
+ lazyEvaluation = newValue;
+ return this;
+ }
+
+ /**
+ * Configure the validator of the model parameters.
+ * Each point proposed by the optimizer is passed (as a copy) to the
+ * validator, and the vector it returns is used in place of that point.
+ *
+ * @param newValidator Parameter validator, or {@code null} for no validation.
+ * @return this object.
+ *
+ * @since 3.4
+ */
+ public LeastSquaresBuilder parameterValidator(final ParameterValidator newValidator) {
+ paramValidator = newValidator;
+ return this;
+ }
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
index 917acfc..1a92ac9 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
@@ -56,22 +56,33 @@ public class LeastSquaresFactory {
* @param maxIterations the maximum number to times to iterate in the algorithm
* @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
* will defer the evaluation until access to the value is requested.
+ * @param weight the weight matrix, or {@code null} to leave the problem unweighted.
+ * @param paramValidator Model parameters validator, or {@code null} for none.
* @return the specified General Least Squares problem.
+ *
+ * @since 3.4
*/
public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
final RealVector observed,
final RealVector start,
+ final RealMatrix weight,
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations,
- final boolean lazyEvaluation) {
- return new LocalLeastSquaresProblem(model,
- observed,
- start,
- checker,
- maxEvaluations,
- maxIterations,
- lazyEvaluation);
+ final boolean lazyEvaluation,
+ final ParameterValidator paramValidator) {
+ final LeastSquaresProblem p = new LocalLeastSquaresProblem(model,
+ observed,
+ start,
+ checker,
+ maxEvaluations,
+ maxIterations,
+ lazyEvaluation,
+ paramValidator);
+ if (weight != null) { // only wrap the problem when a weight matrix was supplied
+ return weightMatrix(p, weight);
+ } else {
+ return p; // unweighted problem
+ }
}
/**
@@ -92,13 +103,15 @@ public class LeastSquaresFactory {
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations) {
- return new LocalLeastSquaresProblem(model,
- observed,
- start,
- checker,
- maxEvaluations,
- maxIterations,
- false);
+ return create(model,
+ observed,
+ start,
+ null, // no weight matrix
+ checker,
+ maxEvaluations,
+ maxIterations,
+ false, // eager evaluation
+ null); // no parameter validator
}
/**
@@ -345,6 +358,8 @@ public class LeastSquaresFactory {
private final RealVector start;
/** Whether to use lazy evaluation. */
private final boolean lazyEvaluation;
+ /** Model parameters validator ({@code null} if no validation is performed). */
+ private final ParameterValidator paramValidator;
/**
* Create a {@link LeastSquaresProblem} from the given data.
@@ -357,6 +372,7 @@ public class LeastSquaresFactory {
* @param maxIterations the allowed iterations
* @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
* will defer the evaluation until access to the value is requested.
+ * @param paramValidator Model parameters validator, or {@code null} for none.
*/
LocalLeastSquaresProblem(final MultivariateJacobianFunction model,
final RealVector target,
@@ -364,12 +380,14 @@ public class LeastSquaresFactory {
final ConvergenceChecker<Evaluation> checker,
final int maxEvaluations,
final int maxIterations,
- boolean lazyEvaluation) {
+ final boolean lazyEvaluation,
+ final ParameterValidator paramValidator) {
super(maxEvaluations, maxIterations, checker);
this.target = target;
this.model = model;
this.start = start;
this.lazyEvaluation = lazyEvaluation;
+ this.paramValidator = paramValidator; // may be null: evaluate() then skips validation
if (lazyEvaluation &&
!(model instanceof ValueAndJacobianFunction)) {
@@ -398,7 +416,9 @@ public class LeastSquaresFactory {
/** {@inheritDoc} */
public Evaluation evaluate(final RealVector point) {
// Copy so optimizer can change point without changing our instance.
- final RealVector p = point.copy();
+ final RealVector p = paramValidator == null ?
+ point.copy() : // no validator: evaluate a plain copy
+ paramValidator.validate(point.copy()); // validator gets its own copy and may return a different vector
if (lazyEvaluation) {
return new LazyUnweightedEvaluation((ValueAndJacobianFunction) model,
http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
index 39e7ae4..180e328 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
@@ -23,7 +23,7 @@ import org.apache.commons.math3.linear.RealVector;
- * A interface for functions that compute a vector of values and can compute their
+ * An interface for functions that compute a vector of values and can compute their
* derivatives (Jacobian).
*
- * @since 3.3
+ * @since 3.4
*/
public interface ValueAndJacobianFunction extends MultivariateJacobianFunction {
/**
http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
index 9cfbe0b..a53b3f7 100644
--- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
+++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
@@ -226,7 +226,7 @@ public class EvaluationTest {
final LeastSquaresProblem p
= LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
- dummy, dummy, null, 0, 0, true);
+ dummy, dummy, null, null, 0, 0, true, null); // observed, start, weight, checker, maxEvaluations, maxIterations, lazyEvaluation, paramValidator
// Should not throw because actual evaluation is deferred.
final Evaluation eval = p.evaluate(dummy);
@@ -263,7 +263,7 @@ public class EvaluationTest {
try {
// Should throw.
- LeastSquaresFactory.create(m1, dummy, dummy, null, 0, 0, true);
+ LeastSquaresFactory.create(m1, dummy, dummy, null, null, 0, 0, true, null); // lazy evaluation requested; no weight, no validator
Assert.fail("Expecting MathIllegalStateException");
} catch (MathIllegalStateException e) {
// Expected.
@@ -282,7 +282,7 @@ public class EvaluationTest {
};
// Should pass.
- LeastSquaresFactory.create(m2, dummy, dummy, null, 0, 0, true);
+ LeastSquaresFactory.create(m2, dummy, dummy, null, null, 0, 0, true, null); // lazy evaluation requested; no weight, no validator
}
@Test
@@ -291,7 +291,7 @@ public class EvaluationTest {
final LeastSquaresProblem p
= LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
- dummy, dummy, null, 0, 0, false);
+ dummy, dummy, null, null, 0, 0, false, null); // eager evaluation; no weight, no validator
try {
// Should throw.
http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
index b2c8f54..46658db 100644
--- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
+++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
@@ -268,6 +268,48 @@ public class LevenbergMarquardtOptimizerTest
}
@Test
+ public void testParameterValidator() {
+ // Setup.
+ final double xCenter = 123.456;
+ final double yCenter = 654.321;
+ final double xSigma = 10;
+ final double ySigma = 15;
+ final double radius = 111.111;
+ final long seed = 3456789L;
+ final RandomCirclePointGenerator factory
+ = new RandomCirclePointGenerator(xCenter, yCenter, radius,
+ xSigma, ySigma,
+ seed);
+ final CircleProblem circle = new CircleProblem(xSigma, ySigma);
+
+ final int numPoints = 10;
+ for (Vector2D p : factory.generate(numPoints)) {
+ circle.addPoint(p.getX(), p.getY());
+ }
+
+ // First guess for the center's coordinates and radius.
+ final double[] init = { 90, 659, 115 };
+ final Optimum optimum
+ = optimizer.optimize(builder(circle).maxIterations(50).start(init).build());
+ final int numEval = optimum.getEvaluations();
+ Assert.assertTrue(numEval > 1); // the honest run needs more than one evaluation
+
+ // Build a new problem with a validator that amounts to cheating: every point proposed by the optimizer is replaced by the optimum found above, so fewer evaluations should be needed.
+ final ParameterValidator cheatValidator
+ = new ParameterValidator() {
+ public RealVector validate(RealVector params) {
+ // Cheat: return the optimum found previously.
+ return optimum.getPoint();
+ }
+ };
+
+ final Optimum cheatOptimum
+ = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
+ final int cheatNumEval = cheatOptimum.getEvaluations();
+ Assert.assertTrue(cheatNumEval < numEval);
+ }
+
+ @Test
public void testEvaluationCount() {
//setup
LeastSquaresProblem lsp = new LinearProblem(new double[][] {{1}}, new double[] {1})