You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2011/02/28 22:32:33 UTC
svn commit: r1075544 - in /commons/proper/math/trunk/src:
main/java/org/apache/commons/math/analysis/function/Logistic.java
test/java/org/apache/commons/math/analysis/function/LogisticTest.java
Author: erans
Date: Mon Feb 28 21:32:33 2011
New Revision: 1075544
URL: http://svn.apache.org/viewvc?rev=1075544&view=rev
Log:
MATH-503
Added derivative function.
Modified:
commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java?rev=1075544&r1=1075543&r2=1075544&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java Mon Feb 28 21:32:33 2011
@@ -18,6 +18,7 @@
package org.apache.commons.math.analysis.function;
import org.apache.commons.math.analysis.UnivariateRealFunction;
+import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction;
import org.apache.commons.math.exception.NotStrictlyPositiveException;
import org.apache.commons.math.util.FastMath;
@@ -28,7 +29,7 @@ import org.apache.commons.math.util.Fast
* @version $Revision$ $Date$
* @since 3.0
*/
-public class Logistic implements UnivariateRealFunction {
+public class Logistic implements DifferentiableUnivariateRealFunction {
/** Lower asymptote. */
private final double a;
/** Upper asymptote. */
@@ -36,7 +37,7 @@ public class Logistic implements Univari
/** Growth rate. */
private final double b;
/** Parameter that affects near which asymptote maximum growth occurs. */
- private final double n;
+ private final double oneOverN;
/** Parameter that affects the position of the curve along the ordinate axis. */
private final double q;
/** Abscissa of maximum growth. */
@@ -70,11 +71,27 @@ public class Logistic implements Univari
this.b = b;
this.q = q;
this.a = a;
- this.n = n;
+ oneOverN = 1 / n;
}
/** {@inheritDoc} */
public double value(double x) {
- return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * (m - x)), 1 / n);
+ return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * (m - x)), oneOverN);
+ }
+
+ /**
+  * {@inheritDoc}
+  *
+  * The returned function evaluates the analytical derivative of
+  * {@link #value(double)} with respect to {@code x}:
+  * {@code (k - a) * b * (1 / n) * E / (1 + E)^(1 / n + 1)},
+  * where {@code E = q * exp(b * (m - x))}.
+  */
+ public UnivariateRealFunction derivative() {
+     return new UnivariateRealFunction() {
+         /** {@inheritDoc} */
+         public double value(double x) {
+             final double exp = q * FastMath.exp(b * (m - x));
+             if (Double.isInfinite(exp)) {
+                 // Avoid returning NaN in case of overflow: the quotient
+                 // below would otherwise evaluate to Infinity / Infinity.
+                 return 0;
+             }
+             final double exp1 = exp + 1;
+             // value(x) scales the power term by (k - a), so its derivative
+             // must carry the same amplitude factor.
+             return (k - a) * b * oneOverN * exp / FastMath.pow(exp1, oneOverN + 1);
+         }
+     };
}
}
Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java?rev=1075544&r1=1075543&r2=1075544&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java (original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java Mon Feb 28 21:32:33 2011
@@ -76,4 +76,25 @@ public class LogisticTest {
x = Double.POSITIVE_INFINITY;
Assert.assertEquals("x=" + x, k, f.value(x), EPS);
}
+
+ /**
+  * Compares the derivative of {@code new Logistic(k, 0, 1, 1, a, 1)} with
+  * the derivative of {@code new Sigmoid(a, k)} on evenly spaced abscissae
+  * covering the closed interval [-10, 10].
+  *
+  * NOTE(review): with k = 3 and a = 2 the difference (k - a) equals 1, so
+  * this comparison cannot detect a missing or wrong (k - a) scale factor
+  * in either derivative.
+  */
+ @Test
+ public void testCompareDerivativeSigmoid() {
+     final double k = 3;
+     final double a = 2;
+
+     final Logistic f = new Logistic(k, 0, 1, 1, a, 1);
+     final Sigmoid g = new Sigmoid(a, k);
+
+     final UnivariateRealFunction dfdx = f.derivative();
+     final UnivariateRealFunction dgdx = g.derivative();
+
+     final double min = -10;
+     final double max = 10;
+     // Number of sub-intervals; an integer count, since it bounds the loop index.
+     final int n = 20;
+     final double delta = (max - min) / n;
+     // "<=" so that the upper end point (x = max) is sampled as well.
+     for (int i = 0; i <= n; i++) {
+         final double x = min + i * delta;
+         Assert.assertEquals("x=" + x, dgdx.value(x), dfdx.value(x), EPS);
+     }
+ }
}