You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ce...@apache.org on 2012/05/16 07:36:40 UTC
svn commit: r1339014 -
/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java
Author: celestin
Date: Wed May 16 05:36:40 2012
New Revision: 1339014
URL: http://svn.apache.org/viewvc?rev=1339014&view=rev
Log:
Unit tests for GammaDistribution, based on reference data generated with
Maxima. Solves MATH-753.
Modified:
commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java
Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java?rev=1339014&r1=1339013&r2=1339014&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java (original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java Wed May 16 05:36:40 2012
@@ -17,7 +17,15 @@
package org.apache.commons.math3.distribution;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.special.Gamma;
+import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
+import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
@@ -167,4 +175,155 @@ public class GammaDistributionTest exten
Assert.assertEquals(dist.getNumericalMean(), 1.1d * 4.2d, tol);
Assert.assertEquals(dist.getNumericalVariance(), 1.1d * 4.2d * 4.2d, tol);
}
+
+ /**
+ * Reference ("old") gamma probability density, kept so that the MATH-753
+ * tests can compare it against {@link GammaDistribution#density(double)}.
+ * <p>
+ * Evaluates {@code (x / scale)^(shape - 1) / scale * exp(-x / scale)},
+ * then divides by {@code exp(logGamma(shape))}. For large {@code x} or
+ * {@code shape}, the intermediate {@code pow}/{@code exp} terms can overflow.
+ * In that case the result may be NaN or infinite; doTestMath753 uses this
+ * to classify data points as "overflow" points.
+ * </p>
+ *
+ * @param x point at which the density is evaluated
+ * @param shape shape parameter
+ * @param scale scale parameter
+ * @return the density at {@code x}, or 0 when {@code x < 0}
+ */
+ public static double density(final double x, final double shape,
+ final double scale) {
+ /*
+ * This is a copy of
+ * double GammaDistribution.density(double)
+ * prior to r1338548.
+ * Keep it verbatim, including the evaluation order: the tests measure
+ * the numerical error of exactly this formulation.
+ */
+ if (x < 0) {
+ return 0;
+ }
+ return FastMath.pow(x / scale, shape - 1) / scale *
+ FastMath.exp(-x / scale) / FastMath.exp(Gamma.logGamma(shape));
+ }
+
+ /*
+ * MATH-753: large values of x or of the shape parameter caused the previous
+ * implementation of density(double) (reproduced in the static
+ * density(double, double, double) method of this class) to overflow.
+ * Reference data are generated with the Maxima script
+ * gamma-distribution.mac, which can be found in
+ * src/test/resources/org/apache/commons/math3/distribution.
+ */
+
+ private void doTestMath753(final double shape,
+ final double meanNoOF, final double sdNoOF,
+ final double meanOF, final double sdOF,
+ final String resourceName) throws IOException {
+ final GammaDistribution distribution = new GammaDistribution(shape, 1.0);
+ final SummaryStatistics statOld = new SummaryStatistics();
+ final SummaryStatistics statNewNoOF = new SummaryStatistics();
+ final SummaryStatistics statNewOF = new SummaryStatistics();
+
+ final InputStream resourceAsStream;
+ resourceAsStream = this.getClass().getResourceAsStream(resourceName);
+ Assert.assertNotNull("Could not find resource " + resourceName,
+ resourceAsStream);
+ final BufferedReader in;
+ in = new BufferedReader(new InputStreamReader(resourceAsStream));
+
+ try {
+ for (String line = in.readLine(); line != null; line = in
+ .readLine()) {
+ final String[] tokens = line.split(", ");
+ Assert.assertTrue("expected two floating-point values",
+ tokens.length == 2);
+ final double x = Double.parseDouble(tokens[0]);
+ final String msg = "x = " + x + ", shape = " + shape +
+ ", scale = 1.0";
+ final double expected = Double.parseDouble(tokens[1]);
+ final double ulp = FastMath.ulp(expected);
+ final double actualOld = density(x, shape, 1.0);
+ final double actualNew = distribution.density(x);
+ final double errOld, errNew;
+ errOld = FastMath.abs((actualOld - expected) / ulp);
+ errNew = FastMath.abs((actualNew - expected) / ulp);
+
+ if (Double.isNaN(actualOld) || Double.isInfinite(actualOld)) {
+ Assert.assertFalse(msg, Double.isNaN(actualNew));
+ Assert.assertFalse(msg, Double.isInfinite(actualNew));
+ statNewOF.addValue(errNew);
+ } else {
+ statOld.addValue(errOld);
+ statNewNoOF.addValue(errNew);
+ }
+ }
+ if (statOld.getN() != 0) {
+ /*
+ * If no overflow occurs, check that new implementation is
+ * better than old one.
+ */
+ final StringBuilder sb = new StringBuilder("shape = ");
+ sb.append(shape);
+ sb.append(", scale = 1.0\n");
+ sb.append("Old implementation\n");
+ sb.append("------------------\n");
+ sb.append(statOld.toString());
+ sb.append("New implementation\n");
+ sb.append("------------------\n");
+ sb.append(statNewNoOF.toString());
+ final String msg = sb.toString();
+
+ final double oldMin = statOld.getMin();
+ final double newMin = statNewNoOF.getMin();
+ Assert.assertTrue(msg, newMin <= oldMin);
+
+ final double oldMax = statOld.getMax();
+ final double newMax = statNewNoOF.getMax();
+ Assert.assertTrue(msg, newMax <= oldMax);
+
+ final double oldMean = statOld.getMean();
+ final double newMean = statNewNoOF.getMean();
+ Assert.assertTrue(msg, newMean <= oldMean);
+
+ final double oldSd = statOld.getStandardDeviation();
+ final double newSd = statNewNoOF.getStandardDeviation();
+ Assert.assertTrue(msg, newSd <= oldSd);
+
+ Assert.assertTrue(msg, newMean <= meanNoOF);
+ Assert.assertTrue(msg, newSd <= sdNoOF);
+ }
+ if (statNewOF.getN() != 0) {
+ final double newMean = statNewOF.getMean();
+ final double newSd = statNewOF.getStandardDeviation();
+
+ final StringBuilder sb = new StringBuilder("shape = ");
+ sb.append(shape);
+ sb.append(", scale = 1.0");
+ sb.append(", max. mean error (ulps) = ");
+ sb.append(meanOF);
+ sb.append(", actual mean error (ulps) = ");
+ sb.append(newMean);
+ sb.append(", max. sd of error (ulps) = ");
+ sb.append(sdOF);
+ sb.append(", actual sd of error (ulps) = ");
+ sb.append(newSd);
+ final String msg = sb.toString();
+
+ Assert.assertTrue(msg, newMean <= meanOF);
+ Assert.assertTrue(msg, newSd <= sdOF);
+ }
+ } catch (IOException e) {
+ Assert.fail(e.getMessage());
+ } finally {
+ in.close();
+ }
+ }
+
+
+    /**
+     * MATH-753 accuracy check for shape = 1.0, using
+     * gamma-distribution-shape-1.csv.
+     */
+    @Test
+    public void testMath753Shape1() throws IOException {
+        final double shape = 1.0;
+        final String resource = "gamma-distribution-shape-1.csv";
+        doTestMath753(shape, 1.5, 0.5, 0.0, 0.0, resource);
+    }
+
+    /**
+     * MATH-753 accuracy check for shape = 10.0, using
+     * gamma-distribution-shape-10.csv.
+     */
+    @Test
+    public void testMath753Shape10() throws IOException {
+        final double shape = 10.0;
+        final String resource = "gamma-distribution-shape-10.csv";
+        doTestMath753(shape, 1.0, 1.0, 0.0, 0.0, resource);
+    }
+
+    /**
+     * MATH-753 accuracy check for shape = 100.0, using
+     * gamma-distribution-shape-100.csv.
+     */
+    @Test
+    public void testMath753Shape100() throws IOException {
+        final double shape = 100.0;
+        final String resource = "gamma-distribution-shape-100.csv";
+        doTestMath753(shape, 1.5, 1.0, 0.0, 0.0, resource);
+    }
+
+    /**
+     * MATH-753 accuracy check for shape = 142.0, using
+     * gamma-distribution-shape-142.csv. The overflow-case bounds
+     * (40 ulps mean, 40 ulps sd) are non-zero for this shape.
+     */
+    @Test
+    public void testMath753Shape142() throws IOException {
+        final double shape = 142.0;
+        final String resource = "gamma-distribution-shape-142.csv";
+        doTestMath753(shape, 0.5, 1.5, 40.0, 40.0, resource);
+    }
+
+    /**
+     * MATH-753 accuracy check for shape = 1000.0, using
+     * gamma-distribution-shape-1000.csv. The overflow-case bounds
+     * (160 ulps mean, 220 ulps sd) are non-zero for this shape.
+     */
+    @Test
+    public void testMath753Shape1000() throws IOException {
+        final double shape = 1000.0;
+        final String resource = "gamma-distribution-shape-1000.csv";
+        doTestMath753(shape, 1.0, 1.0, 160.0, 220.0, resource);
+    }
}