You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2015/05/03 19:18:32 UTC
[math] Converters for univariate and multivariate differentiable
functions.
Repository: commons-math
Updated Branches:
refs/heads/master cb21480cb -> 613afdb0c
Converters for univariate and multivariate differentiable functions.
JIRA: MATH-1143
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/613afdb0
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/613afdb0
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/613afdb0
Branch: refs/heads/master
Commit: 613afdb0c33b9754f049a295ba9f7351de04a748
Parents: cb21480
Author: Luc Maisonobe <lu...@apache.org>
Authored: Sun May 3 19:18:09 2015 +0200
Committer: Luc Maisonobe <lu...@apache.org>
Committed: Sun May 3 19:18:09 2015 +0200
----------------------------------------------------------------------
src/changes/changes.xml | 3 +
.../commons/math4/analysis/FunctionUtils.java | 204 +++++++++++++++++++
.../math4/analysis/FunctionUtilsTest.java | 195 ++++++++++++++++++
3 files changed, 402 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index a342025..d39acc0 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
</release>
<release version="4.0" date="XXXX-XX-XX" description="">
+ <action dev="luc" type="fix" issue="MATH-1143">
+ Added helper methods to FunctionUtils for univariate and multivariate differentiable functions conversion.
+ </action>
<action dev="tn" type="fix" issue="MATH-964">
Removed unused package private class PollardRho in package primes.
</action>
http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java b/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
index dca2e8a..83f7467 100644
--- a/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
+++ b/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
@@ -18,12 +18,14 @@
package org.apache.commons.math4.analysis;
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.function.Identity;
import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
import org.apache.commons.math4.exception.util.LocalizedFormats;
+import org.apache.commons.math4.util.MathArrays;
/**
* Utilities for manipulating function objects.
@@ -337,4 +339,206 @@ public class FunctionUtils {
return s;
}
+ /** Wrap a plain function and its successive derivatives into a
+  * {@link UnivariateDifferentiableFunction}.
+  * <p>
+  * This method handles the case with one free parameter and several derivatives.
+  * For the case with several free parameters and only first order derivatives,
+  * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
+  * There is no direct support for intermediate cases, with several free parameters
+  * and order 2 or more derivatives, as it would be difficult to specify all the
+  * cross derivatives.
+  * </p>
+  * <p>
+  * The derivatives must be expressed with respect to the raw parameter x of the
+  * base function only, i.e. they are df/dx, d<sup>2</sup>f/dx<sup>2</sup>, ...
+  * Even if the built function is later used in a composition like f(sin(t)), the
+  * provided derivatives should <em>not</em> apply the composition with sine and its
+  * derivatives by themselves. The composition is performed automatically here, so the
+  * result properly contains f(sin(t)), df(sin(t))/dt, d<sup>2</sup>f(sin(t))/dt<sup>2</sup>
+  * even though the provided derivative functions know nothing about the sine function.
+  * </p>
+  * <p>
+  * The returned function throws a {@link NumberIsTooLargeException} when it is evaluated
+  * on a {@link DerivativeStructure} whose order exceeds the number of provided derivatives.
+  * </p>
+  * @param f base function f(x)
+  * @param derivatives derivatives of the base function, in increasing differentiation order
+  * @return a differentiable function with value and all specified derivatives
+  * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
+  * @see #derivative(UnivariateDifferentiableFunction, int)
+  */
+ public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
+                                                                 final UnivariateFunction ... derivatives) {
+
+     return new UnivariateDifferentiableFunction() {
+
+         /** {@inheritDoc} */
+         @Override
+         public double value(final double x) {
+             return f.value(x);
+         }
+
+         /** {@inheritDoc} */
+         @Override
+         public DerivativeStructure value(final DerivativeStructure x) {
+
+             // we cannot deliver more derivatives than we were given
+             final int requested = x.getOrder();
+             if (requested > derivatives.length) {
+                 throw new NumberIsTooLargeException(requested, derivatives.length, true);
+             }
+
+             // value and derivatives of f with respect to its raw parameter
+             final double x0 = x.getValue();
+             final double[] taylor = new double[requested + 1];
+             taylor[0] = f.value(x0);
+             for (int k = 1; k <= requested; ++k) {
+                 taylor[k] = derivatives[k - 1].value(x0);
+             }
+
+             // chain rule with whatever x itself depends on
+             return x.compose(taylor);
+
+         }
+
+     };
+
+ }
+
+ /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
+  * <p>
+  * This method handles the case with several free parameters and only first order derivatives.
+  * For the case with one free parameter and several derivatives,
+  * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
+  * There is no direct support for intermediate cases, with several free parameters
+  * and order 2 or more derivatives, as it would be difficult to specify all the
+  * cross derivatives.
+  * </p>
+  * <p>
+  * Note that the gradient is expected to be computed only with respect to the
+  * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
+  * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
+  * gradient should <em>not</em> apply the composition with sine or cosine and their derivatives by
+  * itself. The composition will be done automatically here and the result will properly
+  * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt even though the provided derivative functions
+  * know nothing about the sine or cosine functions.
+  * </p>
+  * <p>
+  * When evaluated on an array of {@link DerivativeStructure}, the returned function throws:
+  * </p>
+  * <ul>
+  *   <li>{@link NotStrictlyPositiveException} if the array is empty,</li>
+  *   <li>{@link NumberIsTooLargeException} if any component has an order greater than 1,</li>
+  *   <li>{@link DimensionMismatchException} if the gradient function returns an array whose
+  *   length differs from the number of components.</li>
+  * </ul>
+  * @param f base function f(x)
+  * @param gradient gradient of the base function
+  * @return a differentiable function with value and gradient
+  * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
+  * @see #derivative(MultivariateDifferentiableFunction, int[])
+  */
+ public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
+                                                                   final MultivariateVectorFunction gradient) {
+
+     return new MultivariateDifferentiableFunction() {
+
+         /** {@inheritDoc} */
+         @Override
+         public double value(final double[] point) {
+             return f.value(point);
+         }
+
+         /** {@inheritDoc} */
+         @Override
+         public DerivativeStructure value(final DerivativeStructure[] point) {
+
+             // the number of free parameters is read from the first component,
+             // so an empty point cannot be processed
+             if (point.length == 0) {
+                 throw new NotStrictlyPositiveException(point.length);
+             }
+
+             // set up the input parameters
+             final double[] dPoint = new double[point.length];
+             for (int i = 0; i < point.length; ++i) {
+                 dPoint[i] = point[i].getValue();
+                 if (point[i].getOrder() > 1) {
+                     // only a gradient is available, higher orders cannot be provided
+                     throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
+                 }
+             }
+
+             // evaluate regular functions
+             final double v = f.value(dPoint);
+             final double[] dv = gradient.value(dPoint);
+             if (dv.length != point.length) {
+                 // the gradient function is inconsistent
+                 throw new DimensionMismatchException(dv.length, point.length);
+             }
+
+             // build the combined derivative
+             final int parameters = point[0].getFreeParameters();
+             final double[] partials = new double[point.length];
+             final double[] packed = new double[parameters + 1];
+             packed[0] = v;
+             final int[] orders = new int[parameters];
+             for (int i = 0; i < parameters; ++i) {
+
+                 // we differentiate once with respect to parameter i
+                 orders[i] = 1;
+                 for (int j = 0; j < point.length; ++j) {
+                     partials[j] = point[j].getPartialDerivative(orders);
+                 }
+                 orders[i] = 0;
+
+                 // compose partial derivatives (chain rule: sum of df/dx_j * dx_j/dp_i)
+                 packed[i + 1] = MathArrays.linearCombination(dv, partials);
+
+             }
+
+             return new DerivativeStructure(parameters, 1, packed);
+
+         }
+
+     };
+
+ }
+
+ /** Extract the n<sup>th</sup> order derivative of a
+  * {@link UnivariateDifferentiableFunction} as a plain {@link UnivariateFunction}.
+  * <p>
+  * This converter is only a convenience method. Beware that computing only one derivative
+  * does not save any computation, as the original function is really called under the hood
+  * and the derivative is then picked out of the full {@link DerivativeStructure} result.
+  * </p>
+  * @param f original function, with value and all its derivatives
+  * @param order order of the derivative to extract
+  * @return function computing the derivative at required order
+  * @see #derivative(MultivariateDifferentiableFunction, int[])
+  * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
+  */
+ public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
+     return new UnivariateFunction() {
+
+         /** {@inheritDoc} */
+         @Override
+         public double value(final double x) {
+             // x is the single free parameter (index 0), tracked up to the requested order
+             final DerivativeStructure full = f.value(new DerivativeStructure(1, order, 0, x));
+             return full.getPartialDerivative(order);
+         }
+
+     };
+ }
+
+ /** Convert a {@link MultivariateDifferentiableFunction} to a
+  * {@link MultivariateFunction} computing one selected partial derivative.
+  * <p>
+  * This converter is only a convenience method. Beware that computing only one derivative
+  * does not save any computation, as the original function is really called under the hood.
+  * The derivative is extracted from the full {@link DerivativeStructure} result.
+  * </p>
+  * <p>
+  * The {@code orders} array is copied when this method is called, so later changes
+  * to the caller's array do not affect the returned function.
+  * </p>
+  * @param f original function, with value and all its derivatives
+  * @param orders derivation orders to extract, one for each free parameter
+  * @return function computing the derivative at required order
+  * @see #derivative(UnivariateDifferentiableFunction, int)
+  * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
+  */
+ public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
+
+     // defensive copy: the returned function must not change if the caller reuses its array
+     final int[] requestedOrders = orders.clone();
+
+     // the maximum differentiation order is the sum of all orders,
+     // it does not depend on the evaluation point so it is computed only once
+     int sum = 0;
+     for (final int order : requestedOrders) {
+         sum += order;
+     }
+     final int sumOrders = sum;
+
+     return new MultivariateFunction() {
+
+         /** {@inheritDoc} */
+         @Override
+         public double value(final double[] point) {
+
+             // set up the input parameters: component i is free parameter number i
+             final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
+             for (int i = 0; i < point.length; ++i) {
+                 dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
+             }
+
+             return f.value(dsPoint).getPartialDerivative(requestedOrders);
+
+         }
+
+     };
+ }
+
}
http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java b/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
index bfa4340..ee6be54 100644
--- a/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
+++ b/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
@@ -18,6 +18,7 @@
package org.apache.commons.math4.analysis;
import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math4.analysis.function.Add;
import org.apache.commons.math4.analysis.function.Constant;
@@ -35,6 +36,7 @@ import org.apache.commons.math4.analysis.function.Pow;
import org.apache.commons.math4.analysis.function.Power;
import org.apache.commons.math4.analysis.function.Sin;
import org.apache.commons.math4.analysis.function.Sinc;
+import org.apache.commons.math4.exception.DimensionMismatchException;
import org.apache.commons.math4.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.exception.NumberIsTooLargeException;
import org.apache.commons.math4.util.FastMath;
@@ -233,4 +235,197 @@ public class FunctionUtilsTest {
}
}
+ /** Check value, first and second derivatives of x -> x<sup>2</sup> composed with sine,
+  * and the rejection of a derivation order higher than the number of provided derivatives.
+  */
+ @Test
+ public void testToDifferentiableUnivariate() {
+
+     // f(x) = x^2
+     final UnivariateFunction square = new UnivariateFunction() {
+         @Override
+         public double value(final double x) {
+             return x * x;
+         }
+     };
+     // f'(x) = 2x
+     final UnivariateFunction firstDerivative = new UnivariateFunction() {
+         @Override
+         public double value(final double x) {
+             return 2 * x;
+         }
+     };
+     // f''(x) = 2
+     final UnivariateFunction secondDerivative = new UnivariateFunction() {
+         @Override
+         public double value(final double x) {
+             return 2;
+         }
+     };
+     final UnivariateDifferentiableFunction f =
+             FunctionUtils.toDifferentiable(square, firstDerivative, secondDerivative);
+
+     for (double t = -1.0; t < 1; t += 0.01) {
+         // x = sin(t)
+         final double s = FastMath.sin(t);
+         final double c = FastMath.cos(t);
+         final DerivativeStructure dsT = new DerivativeStructure(1, 2, 0, t);
+         final DerivativeStructure y = f.value(dsT.sin());
+         Assert.assertEquals(s * s, f.value(s), 1.0e-15);
+         Assert.assertEquals(s * s, y.getValue(), 1.0e-15);
+         Assert.assertEquals(2 * c * s, y.getPartialDerivative(1), 1.0e-15);
+         Assert.assertEquals(2 * (1 - 2 * s * s), y.getPartialDerivative(2), 1.0e-15);
+     }
+
+     // only two derivatives were provided, order 3 must be rejected
+     try {
+         f.value(new DerivativeStructure(1, 3, 0.0));
+         Assert.fail("an exception should have been thrown");
+     } catch (NumberIsTooLargeException e) {
+         Assert.assertEquals(2, e.getMax());
+         Assert.assertEquals(3, e.getArgument());
+     }
+
+ }
+
+ /** Check a linear bivariate function wrapped with its constant gradient, first composed
+  * with (sin(t), cos(t)), then with two independent free parameters, and finally check
+  * the rejection of inputs with a derivation order above one.
+  */
+ @Test
+ public void testToDifferentiableMultivariate() {
+
+     final double a = 1.5;
+     final double b = 0.5;
+
+     // f(x, y) = a x + b y
+     final MultivariateFunction linear = new MultivariateFunction() {
+         @Override
+         public double value(final double[] point) {
+             return a * point[0] + b * point[1];
+         }
+     };
+     // grad f = (a, b)
+     final MultivariateVectorFunction constantGradient = new MultivariateVectorFunction() {
+         @Override
+         public double[] value(final double[] point) {
+             return new double[] { a, b };
+         }
+     };
+     final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(linear, constantGradient);
+
+     for (double t = -1.0; t < 1; t += 0.01) {
+         // x = sin(t), y = cos(t), hence the method really becomes univariate
+         final double s = FastMath.sin(t);
+         final double c = FastMath.cos(t);
+         final DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, t);
+         final DerivativeStructure[] onCircle = new DerivativeStructure[] { dsT.sin(), dsT.cos() };
+         final DerivativeStructure y = mdf.value(onCircle);
+         Assert.assertEquals(a * s + b * c, y.getValue(), 1.0e-15);
+         Assert.assertEquals(a * c - b * s, y.getPartialDerivative(1), 1.0e-15);
+     }
+
+     // two independent free parameters u (index 0) and v (index 1)
+     for (double u = -1.0; u < 1; u += 0.01) {
+         final DerivativeStructure dsU = new DerivativeStructure(2, 1, 0, u);
+         for (double v = -1.0; v < 1; v += 0.01) {
+             final DerivativeStructure dsV = new DerivativeStructure(2, 1, 1, v);
+             final DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
+             final double expected = a * u + b * v;
+             Assert.assertEquals(expected, mdf.value(new double[] { u, v }), 1.0e-15);
+             Assert.assertEquals(expected, y.getValue(), 1.0e-15);
+             Assert.assertEquals(a, y.getPartialDerivative(1, 0), 1.0e-15);
+             Assert.assertEquals(b, y.getPartialDerivative(0, 1), 1.0e-15);
+         }
+     }
+
+     // only the gradient is available, order 3 must be rejected
+     try {
+         final DerivativeStructure[] tooHighOrder = new DerivativeStructure[] {
+             new DerivativeStructure(1, 3, 0.0),
+             new DerivativeStructure(1, 3, 0.0)
+         };
+         mdf.value(tooHighOrder);
+         Assert.fail("an exception should have been thrown");
+     } catch (NumberIsTooLargeException e) {
+         Assert.assertEquals(1, e.getMax());
+         Assert.assertEquals(3, e.getArgument());
+     }
+
+ }
+
+ /** Check that a gradient function returning three components for a
+  * two-component point is reported as a dimension mismatch.
+  */
+ @Test
+ public void testToDifferentiableMultivariateInconsistentGradient() {
+
+     final double a = 1.5;
+     final double b = 0.5;
+
+     // f(x, y) = a x + b y
+     final MultivariateFunction linear = new MultivariateFunction() {
+         @Override
+         public double value(final double[] point) {
+             return a * point[0] + b * point[1];
+         }
+     };
+     // deliberately wrong: one component too many
+     final MultivariateVectorFunction oversizedGradient = new MultivariateVectorFunction() {
+         @Override
+         public double[] value(final double[] point) {
+             return new double[] { a, b, 0.0 };
+         }
+     };
+     final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(linear, oversizedGradient);
+
+     try {
+         final DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, 0.0);
+         final DerivativeStructure[] twoComponents = new DerivativeStructure[] { dsT.sin(), dsT.cos() };
+         mdf.value(twoComponents);
+         Assert.fail("an exception should have been thrown");
+     } catch (DimensionMismatchException e) {
+         Assert.assertEquals(2, e.getDimension());
+         Assert.assertEquals(3, e.getArgument());
+     }
+
+ }
+
+ /** Check extraction of the value, first and second derivatives of x -> x<sup>2</sup>. */
+ @Test
+ public void testDerivativeUnivariate() {
+
+     // f(x) = x^2, with derivatives computed by DerivativeStructure arithmetic
+     final UnivariateDifferentiableFunction square = new UnivariateDifferentiableFunction() {
+
+         @Override
+         public double value(double x) {
+             return x * x;
+         }
+
+         @Override
+         public DerivativeStructure value(DerivativeStructure x) {
+             return x.multiply(x);
+         }
+
+     };
+
+     final UnivariateFunction order0 = FunctionUtils.derivative(square, 0);
+     final UnivariateFunction order1 = FunctionUtils.derivative(square, 1);
+     final UnivariateFunction order2 = FunctionUtils.derivative(square, 2);
+
+     for (double t = -1.0; t < 1; t += 0.01) {
+         Assert.assertEquals(t * t, order0.value(t), 1.0e-15);
+         Assert.assertEquals(2 * t, order1.value(t), 1.0e-15);
+         Assert.assertEquals(2, order2.value(t), 1.0e-15);
+     }
+
+ }
+
+ /** Check extraction of the value and of all first and second order partial derivatives
+  * of the quadratic form a x<sup>2</sup> + b y<sup>2</sup> + c x y.
+  */
+ @Test
+ public void testDerivativeMultivariate() {
+
+     final double a = 1.5;
+     final double b = 0.5;
+     final double c = 0.25;
+
+     // f(x, y) = a x^2 + b y^2 + c x y
+     final MultivariateDifferentiableFunction quadratic = new MultivariateDifferentiableFunction() {
+
+         @Override
+         public double value(double[] point) {
+             return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
+         }
+
+         @Override
+         public DerivativeStructure value(DerivativeStructure[] point) {
+             final DerivativeStructure x = point[0];
+             final DerivativeStructure y = point[1];
+             final DerivativeStructure squareX = x.multiply(x).multiply(a);
+             final DerivativeStructure squareY = y.multiply(y).multiply(b);
+             final DerivativeStructure cross = x.multiply(y).multiply(c);
+             return squareX.add(squareY).add(cross);
+         }
+
+     };
+
+     final MultivariateFunction f = FunctionUtils.derivative(quadratic, new int[] { 0, 0 });
+     final MultivariateFunction dfdx = FunctionUtils.derivative(quadratic, new int[] { 1, 0 });
+     final MultivariateFunction dfdy = FunctionUtils.derivative(quadratic, new int[] { 0, 1 });
+     final MultivariateFunction d2fdx2 = FunctionUtils.derivative(quadratic, new int[] { 2, 0 });
+     final MultivariateFunction d2fdy2 = FunctionUtils.derivative(quadratic, new int[] { 0, 2 });
+     final MultivariateFunction d2fdxdy = FunctionUtils.derivative(quadratic, new int[] { 1, 1 });
+
+     for (double x = -1.0; x < 1; x += 0.01) {
+         for (double y = -1.0; y < 1; y += 0.01) {
+             final double[] xy = new double[] { x, y };
+             Assert.assertEquals(a * x * x + b * y * y + c * x * y, f.value(xy), 1.0e-15);
+             Assert.assertEquals(2 * a * x + c * y, dfdx.value(xy), 1.0e-15);
+             Assert.assertEquals(2 * b * y + c * x, dfdy.value(xy), 1.0e-15);
+             Assert.assertEquals(2 * a, d2fdx2.value(xy), 1.0e-15);
+             Assert.assertEquals(2 * b, d2fdy2.value(xy), 1.0e-15);
+             Assert.assertEquals(c, d2fdxdy.value(xy), 1.0e-15);
+         }
+     }
+
+ }
+
}