You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2015/05/03 19:18:32 UTC

[math] Converters for univariate and multivariate differentiable functions.

Repository: commons-math
Updated Branches:
  refs/heads/master cb21480cb -> 613afdb0c


Converters for univariate and multivariate differentiable functions.

JIRA: MATH-1143


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/613afdb0
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/613afdb0
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/613afdb0

Branch: refs/heads/master
Commit: 613afdb0c33b9754f049a295ba9f7351de04a748
Parents: cb21480
Author: Luc Maisonobe <lu...@apache.org>
Authored: Sun May 3 19:18:09 2015 +0200
Committer: Luc Maisonobe <lu...@apache.org>
Committed: Sun May 3 19:18:09 2015 +0200

----------------------------------------------------------------------
 src/changes/changes.xml                         |   3 +
 .../commons/math4/analysis/FunctionUtils.java   | 204 +++++++++++++++++++
 .../math4/analysis/FunctionUtilsTest.java       | 195 ++++++++++++++++++
 3 files changed, 402 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index a342025..d39acc0 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,9 @@ If the output is not quite correct, check for invisible trailing spaces!
     </release>
 
     <release version="4.0" date="XXXX-XX-XX" description="">
+      <action dev="luc" type="add" issue="MATH-1143">
+        Added helper methods to FunctionUtils for univariate and multivariate differentiable functions conversion.
+      </action>
       <action dev="tn" type="fix" issue="MATH-964">
         Removed unused package private class PollardRho in package primes.
       </action>

http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java b/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
index dca2e8a..83f7467 100644
--- a/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
+++ b/src/main/java/org/apache/commons/math4/analysis/FunctionUtils.java
@@ -18,12 +18,14 @@
 package org.apache.commons.math4.analysis;
 
 import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
 import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
 import org.apache.commons.math4.analysis.function.Identity;
 import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.exception.NotStrictlyPositiveException;
 import org.apache.commons.math4.exception.NumberIsTooLargeException;
 import org.apache.commons.math4.exception.util.LocalizedFormats;
+import org.apache.commons.math4.util.MathArrays;
 
 /**
  * Utilities for manipulating function objects.
@@ -337,4 +339,206 @@ public class FunctionUtils {
         return s;
     }
 
+    /** Convert regular functions to {@link UnivariateDifferentiableFunction}.
+     * <p>
+     * This method handles the case with one free parameter and several derivatives.
+     * For the case with several free parameters and only first order derivatives,
+     * see {@link #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)}.
+     * There is no direct support for intermediate cases, with several free parameters
+     * and order 2 or more derivatives, as it would be difficult to specify all the
+     * cross derivatives.
+     * </p>
+     * <p>
+     * Note that the derivatives are expected to be computed only with respect to the
+     * raw parameter x of the base function, i.e. they are df/dx, d<sup>2</sup>f/dx<sup>2</sup>, ...
+     * Even if the built function is later used in a composition like f(sin(t)), the provided
+     * derivatives should <em>not</em> apply the composition with sine and its derivatives by
+     * themselves. The composition will be done automatically here and the result will properly
+     * contain f(sin(t)), df(sin(t))/dt, d<sup>2</sup>f(sin(t))/dt<sup>2</sup> even though the
+     * provided derivative functions know nothing about the sine function.
+     * </p>
+     * @param f base function f(x)
+     * @param derivatives derivatives of the base function, in increasing differentiation order
+     * @return a differentiable function with value and all specified derivatives
+     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
+     * @see #derivative(UnivariateDifferentiableFunction, int)
+     */
+    public static UnivariateDifferentiableFunction toDifferentiable(final UnivariateFunction f,
+                                                                       final UnivariateFunction ... derivatives) {
+
+        return new UnivariateDifferentiableFunction() {
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(final double x) {
+                return f.value(x); // plain evaluation delegates to the base function only
+            }
+
+            /** {@inheritDoc} */
+            @Override
+            public DerivativeStructure value(final DerivativeStructure x) {
+                if (x.getOrder() > derivatives.length) { // one derivative function is needed per requested order
+                    throw new NumberIsTooLargeException(x.getOrder(), derivatives.length, true);
+                }
+                final double[] packed = new double[x.getOrder() + 1]; // slot 0: value, slot k: k-th derivative
+                packed[0] = f.value(x.getValue());
+                for (int i = 0; i < x.getOrder(); ++i) {
+                    packed[i + 1] = derivatives[i].value(x.getValue()); // derivatives[i] is the (i+1)-th derivative
+                }
+                return x.compose(packed); // composition with x's own derivatives is delegated to DerivativeStructure
+            }
+
+        };
+
+    }
+
+    /** Convert regular functions to {@link MultivariateDifferentiableFunction}.
+     * <p>
+     * This method handles the case with several free parameters and only first order derivatives.
+     * For the case with one free parameter and several derivatives,
+     * see {@link #toDifferentiable(UnivariateFunction, UnivariateFunction...)}.
+     * There is no direct support for intermediate cases, with several free parameters
+     * and order 2 or more derivatives, as it would be difficult to specify all the
+     * cross derivatives.
+     * </p>
+     * <p>
+     * Note that the gradient is expected to be computed only with respect to the
+     * raw parameter x of the base function, i.e. it is df/dx<sub>1</sub>, df/dx<sub>2</sub>, ...
+     * Even if the built function is later used in a composition like f(sin(t), cos(t)), the provided
+     * gradient should <em>not</em> apply the composition with sine or cosine and their derivatives by
+     * itself. The composition will be done automatically here and the result will properly
+     * contain f(sin(t), cos(t)), df(sin(t), cos(t))/dt even though the provided gradient function
+     * knows nothing about the sine or cosine functions.
+     * </p>
+     * @param f base function f(x)
+     * @param gradient gradient of the base function
+     * @return a differentiable function with value and gradient
+     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
+     * @see #derivative(MultivariateDifferentiableFunction, int[])
+     */
+    public static MultivariateDifferentiableFunction toDifferentiable(final MultivariateFunction f,
+                                                                         final MultivariateVectorFunction gradient) {
+
+        return new MultivariateDifferentiableFunction() {
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(final double[] point) {
+                return f.value(point); // plain evaluation delegates to the base function only
+            }
+
+            /** {@inheritDoc} */
+            @Override
+            public DerivativeStructure value(final DerivativeStructure[] point) {
+
+                // extract the plain values of the arguments, rejecting any order above 1
+                final double[] dPoint = new double[point.length];
+                for (int i = 0; i < point.length; ++i) {
+                    dPoint[i] = point[i].getValue();
+                    if (point[i].getOrder() > 1) { // only a gradient is available, so order 2+ cannot be served
+                        throw new NumberIsTooLargeException(point[i].getOrder(), 1, true);
+                    }
+                }
+
+                // evaluate regular functions
+                final double    v = f.value(dPoint);
+                final double[] dv = gradient.value(dPoint);
+                if (dv.length != point.length) {
+                    // the gradient function is inconsistent
+                    throw new DimensionMismatchException(dv.length, point.length);
+                }
+
+                // build the combined derivative
+                final int parameters = point[0].getFreeParameters(); // NOTE(review): assumes point is non-empty and all elements share this count - confirm
+                final double[] partials = new double[point.length];
+                final double[] packed = new double[parameters + 1]; // slot 0: value, slot i + 1: derivative w.r.t. free parameter i
+                packed[0] = v;
+                final int orders[] = new int[parameters];
+                for (int i = 0; i < parameters; ++i) {
+
+                    // we differentiate once with respect to parameter i
+                    orders[i] = 1;
+                    for (int j = 0; j < point.length; ++j) {
+                        partials[j] = point[j].getPartialDerivative(orders);
+                    }
+                    orders[i] = 0; // reset so that the next iteration selects a single parameter again
+
+                    // compose partial derivatives (chain rule: sum over j of df/dx_j * dx_j/dp_i)
+                    packed[i + 1] = MathArrays.linearCombination(dv, partials);
+
+                }
+
+                return new DerivativeStructure(parameters, 1, packed);
+
+            }
+
+        };
+
+    }
+
+    /** Convert a {@link UnivariateDifferentiableFunction} to a
+     * {@link UnivariateFunction} computing the n<sup>th</sup> order derivative.
+     * <p>
+     * This converter is only a convenience method. Beware that computing only one derivative does
+     * not save any computation, as the original function will really be called under the hood.
+     * The derivative will be extracted from the full {@link DerivativeStructure} result.
+     * </p>
+     * @param f original function, with value and all its derivatives
+     * @param order order of the derivative to extract
+     * @return function computing the derivative at required order
+     * @see #derivative(MultivariateDifferentiableFunction, int[])
+     * @see #toDifferentiable(UnivariateFunction, UnivariateFunction...)
+     */
+    public static UnivariateFunction derivative(final UnivariateDifferentiableFunction f, final int order) {
+        return new UnivariateFunction() {
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(final double x) {
+                final DerivativeStructure dsX = new DerivativeStructure(1, order, 0, x); // x as the single free parameter (index 0)
+                return f.value(dsX).getPartialDerivative(order); // keep only the requested order from the full result
+            }
+
+        };
+    }
+
+    /** Convert a {@link MultivariateDifferentiableFunction} to a
+     * {@link MultivariateFunction} computing the n<sup>th</sup> order derivative.
+     * <p>
+     * This converter is only a convenience method. Beware that computing only one derivative does
+     * not save any computation, as the original function will really be called under the hood.
+     * The derivative will be extracted from the full {@link DerivativeStructure} result.
+     * </p>
+     * @param f original function, with value and all its derivatives
+     * @param orders orders of the derivative to extract, one for each free parameter
+     * @return function computing the derivative at required order
+     * @see #derivative(UnivariateDifferentiableFunction, int)
+     * @see #toDifferentiable(MultivariateFunction, MultivariateVectorFunction)
+     */
+    public static MultivariateFunction derivative(final MultivariateDifferentiableFunction f, final int[] orders) {
+        return new MultivariateFunction() {
+
+            /** {@inheritDoc} */
+            @Override
+            public double value(final double[] point) {
+
+                // the maximum differentiation order is the sum of all orders
+                int sumOrders = 0;
+                for (final int order : orders) {
+                    sumOrders += order;
+                }
+
+                // set up the input parameters: component i becomes free parameter i
+                final DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
+                for (int i = 0; i < point.length; ++i) {
+                    dsPoint[i] = new DerivativeStructure(point.length, sumOrders, i, point[i]);
+                }
+
+                return f.value(dsPoint).getPartialDerivative(orders); // keep only the requested partial derivative
+
+            }
+
+        };
+    }
+
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/613afdb0/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java b/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
index bfa4340..ee6be54 100644
--- a/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
+++ b/src/test/java/org/apache/commons/math4/analysis/FunctionUtilsTest.java
@@ -18,6 +18,7 @@
 package org.apache.commons.math4.analysis;
 
 import org.apache.commons.math4.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math4.analysis.differentiation.MultivariateDifferentiableFunction;
 import org.apache.commons.math4.analysis.differentiation.UnivariateDifferentiableFunction;
 import org.apache.commons.math4.analysis.function.Add;
 import org.apache.commons.math4.analysis.function.Constant;
@@ -35,6 +36,7 @@ import org.apache.commons.math4.analysis.function.Pow;
 import org.apache.commons.math4.analysis.function.Power;
 import org.apache.commons.math4.analysis.function.Sin;
 import org.apache.commons.math4.analysis.function.Sinc;
+import org.apache.commons.math4.exception.DimensionMismatchException;
 import org.apache.commons.math4.exception.NotStrictlyPositiveException;
 import org.apache.commons.math4.exception.NumberIsTooLargeException;
 import org.apache.commons.math4.util.FastMath;
@@ -233,4 +235,197 @@ public class FunctionUtilsTest {
         }
     }
 
+    @Test
+    public void testToDifferentiableUnivariate() {
+
+        final UnivariateFunction f0 = new UnivariateFunction() { // base function: x^2
+            @Override
+            public double value(final double x) {
+                return x * x;
+            }
+        };
+        final UnivariateFunction f1 = new UnivariateFunction() { // first derivative: 2x
+            @Override
+            public double value(final double x) {
+                return 2 * x;
+            }
+        };
+        final UnivariateFunction f2 = new UnivariateFunction() { // second derivative: 2
+            @Override
+            public double value(final double x) {
+                return 2;
+            }
+        };
+        final UnivariateDifferentiableFunction f = FunctionUtils.toDifferentiable(f0, f1, f2);
+
+        for (double t = -1.0; t < 1; t += 0.01) {
+            // x = sin(t), so the expected values below are sin^2(t) and its derivatives with respect to t
+            DerivativeStructure dsT = new DerivativeStructure(1, 2, 0, t);
+            DerivativeStructure y = f.value(dsT.sin());
+            Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t),               f.value(FastMath.sin(t)),  1.0e-15);
+            Assert.assertEquals(FastMath.sin(t) * FastMath.sin(t),               y.getValue(),              1.0e-15);
+            Assert.assertEquals(2 * FastMath.cos(t) * FastMath.sin(t),           y.getPartialDerivative(1), 1.0e-15);
+            Assert.assertEquals(2 * (1 - 2 * FastMath.sin(t) * FastMath.sin(t)), y.getPartialDerivative(2), 1.0e-15);
+        }
+
+        try {
+            f.value(new DerivativeStructure(1, 3, 0.0)); // order 3 requested but only 2 derivatives were supplied
+            Assert.fail("an exception should have been thrown");
+        } catch (NumberIsTooLargeException e) {
+            Assert.assertEquals(2, e.getMax());
+            Assert.assertEquals(3, e.getArgument());
+        }
+    }
+
+    @Test
+    public void testToDifferentiableMultivariate() {
+
+        final double a = 1.5;
+        final double b = 0.5;
+        final MultivariateFunction f = new MultivariateFunction() { // linear function a * x + b * y
+            @Override
+            public double value(final double[] point) {
+                return a * point[0] + b * point[1];
+            }
+        };
+        final MultivariateVectorFunction gradient = new MultivariateVectorFunction() { // constant gradient (a, b)
+            @Override
+            public double[] value(final double[] point) {
+                return new double[] { a, b };
+            }
+        };
+        final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
+
+        for (double t = -1.0; t < 1; t += 0.01) {
+            // x = sin(t), y = cos(t), hence the method really becomes univariate
+            DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, t);
+            DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() });
+            Assert.assertEquals(a * FastMath.sin(t) + b * FastMath.cos(t), y.getValue(),              1.0e-15);
+            Assert.assertEquals(a * FastMath.cos(t) - b * FastMath.sin(t), y.getPartialDerivative(1), 1.0e-15);
+        }
+
+        for (double u = -1.0; u < 1; u += 0.01) {
+            DerivativeStructure dsU = new DerivativeStructure(2, 1, 0, u); // u is free parameter 0 out of 2
+            for (double v = -1.0; v < 1; v += 0.01) {
+                DerivativeStructure dsV = new DerivativeStructure(2, 1, 1, v); // v is free parameter 1 out of 2
+                DerivativeStructure y = mdf.value(new DerivativeStructure[] { dsU, dsV });
+                Assert.assertEquals(a * u + b * v, mdf.value(new double[] { u, v }), 1.0e-15);
+                Assert.assertEquals(a * u + b * v, y.getValue(),                     1.0e-15);
+                Assert.assertEquals(a,             y.getPartialDerivative(1, 0),     1.0e-15);
+                Assert.assertEquals(b,             y.getPartialDerivative(0, 1),     1.0e-15);
+            }
+        }
+
+        try { // order 3 inputs must be rejected since only first order derivatives are supported
+            mdf.value(new DerivativeStructure[] { new DerivativeStructure(1, 3, 0.0), new DerivativeStructure(1, 3, 0.0) });
+            Assert.fail("an exception should have been thrown");
+        } catch (NumberIsTooLargeException e) {
+            Assert.assertEquals(1, e.getMax());
+            Assert.assertEquals(3, e.getArgument());
+        }
+    }
+
+    @Test
+    public void testToDifferentiableMultivariateInconsistentGradient() {
+
+        final double a = 1.5;
+        final double b = 0.5;
+        final MultivariateFunction f = new MultivariateFunction() { // function of two variables
+            @Override
+            public double value(final double[] point) {
+                return a * point[0] + b * point[1];
+            }
+        };
+        final MultivariateVectorFunction gradient = new MultivariateVectorFunction() {
+            @Override
+            public double[] value(final double[] point) {
+                return new double[] { a, b, 0.0 }; // three components for a two-variable function: deliberately inconsistent
+            }
+        };
+        final MultivariateDifferentiableFunction mdf = FunctionUtils.toDifferentiable(f, gradient);
+
+        try {
+            DerivativeStructure dsT = new DerivativeStructure(1, 1, 0, 0.0);
+            mdf.value(new DerivativeStructure[] { dsT.sin(), dsT.cos() }); // two arguments versus a gradient of length 3
+            Assert.fail("an exception should have been thrown");
+        } catch (DimensionMismatchException e) {
+            Assert.assertEquals(2, e.getDimension());
+            Assert.assertEquals(3, e.getArgument());
+        }
+    }
+
+    @Test
+    public void testDerivativeUnivariate() {
+
+        final UnivariateDifferentiableFunction f = new UnivariateDifferentiableFunction() { // x^2 in both flavors
+            
+            @Override
+            public double value(double x) {
+                return x * x;
+            }
+            
+            @Override
+            public DerivativeStructure value(DerivativeStructure x) {
+                return x.multiply(x);
+            }
+
+        };
+
+        final UnivariateFunction f0 = FunctionUtils.derivative(f, 0); // order 0: checked against t * t below
+        final UnivariateFunction f1 = FunctionUtils.derivative(f, 1); // order 1: checked against 2 * t below
+        final UnivariateFunction f2 = FunctionUtils.derivative(f, 2); // order 2: checked against 2 below
+
+        for (double t = -1.0; t < 1; t += 0.01) {
+            Assert.assertEquals(t * t, f0.value(t), 1.0e-15);
+            Assert.assertEquals(2 * t, f1.value(t), 1.0e-15);
+            Assert.assertEquals(2,     f2.value(t), 1.0e-15);
+        }
+
+    }
+
+    @Test
+    public void testDerivativeMultivariate() {
+
+        final double a = 1.5;
+        final double b = 0.5;
+        final double c = 0.25;
+        final MultivariateDifferentiableFunction mdf = new MultivariateDifferentiableFunction() { // a x^2 + b y^2 + c x y
+            
+            @Override
+            public double value(double[] point) {
+                return a * point[0] * point[0] + b * point[1] * point[1] + c * point[0] * point[1];
+            }
+            
+            @Override
+            public DerivativeStructure value(DerivativeStructure[] point) {
+                DerivativeStructure x  = point[0];
+                DerivativeStructure y  = point[1];
+                DerivativeStructure x2 = x.multiply(x);
+                DerivativeStructure y2 = y.multiply(y);
+                DerivativeStructure xy = x.multiply(y);
+                return x2.multiply(a).add(y2.multiply(b)).add(xy.multiply(c));
+            }
+
+        };
+
+        final MultivariateFunction f       = FunctionUtils.derivative(mdf, new int[] { 0, 0 }); // orders { 0, 0 }: the value itself
+        final MultivariateFunction dfdx    = FunctionUtils.derivative(mdf, new int[] { 1, 0 });
+        final MultivariateFunction dfdy    = FunctionUtils.derivative(mdf, new int[] { 0, 1 });
+        final MultivariateFunction d2fdx2  = FunctionUtils.derivative(mdf, new int[] { 2, 0 });
+        final MultivariateFunction d2fdy2  = FunctionUtils.derivative(mdf, new int[] { 0, 2 });
+        final MultivariateFunction d2fdxdy = FunctionUtils.derivative(mdf, new int[] { 1, 1 }); // cross derivative
+
+        for (double x = -1.0; x < 1; x += 0.01) {
+            for (double y = -1.0; y < 1; y += 0.01) {
+                Assert.assertEquals(a * x * x + b * y * y + c * x * y, f.value(new double[]       { x, y }), 1.0e-15);
+                Assert.assertEquals(2 * a * x + c * y,                 dfdx.value(new double[]    { x, y }), 1.0e-15);
+                Assert.assertEquals(2 * b * y + c * x,                 dfdy.value(new double[]    { x, y }), 1.0e-15);
+                Assert.assertEquals(2 * a,                             d2fdx2.value(new double[]  { x, y }), 1.0e-15);
+                Assert.assertEquals(2 * b,                             d2fdy2.value(new double[]  { x, y }), 1.0e-15);
+                Assert.assertEquals(c,                                 d2fdxdy.value(new double[] { x, y }), 1.0e-15);
+            }
+        }
+
+    }
+
 }