You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2021/08/06 08:15:11 UTC
[commons-rng] branch master updated: RNG-159: Add sampling test for ZigguratSampler.NormalizedGaussian
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
The following commit(s) were added to refs/heads/master by this push:
new 7985318 RNG-159: Add sampling test for ZigguratSampler.NormalizedGaussian
7985318 is described below
commit 7985318109cc518077a193c334b11c08f6f72e91
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Fri Aug 6 09:15:08 2021 +0100
RNG-159: Add sampling test for ZigguratSampler.NormalizedGaussian
These tests are marked with @Ignore because the sampler generates incorrect values
around the mean.
---
.../sampling/distribution/ZigguratSamplerTest.java | 135 +++++++++++++++++++++
1 file changed, 135 insertions(+)
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratSamplerTest.java
index 7275387..2b881f2 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratSamplerTest.java
@@ -16,7 +16,12 @@
*/
package org.apache.commons.rng.sampling.distribution;
+import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
+import java.util.Arrays;
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.commons.math3.stat.inference.ChiSquareTest;
import org.apache.commons.rng.RestorableUniformRandomProvider;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.RandomAssert;
@@ -72,4 +77,134 @@ public class ZigguratSamplerTest {
final ZigguratSampler.NormalizedGaussian sampler2 = sampler1.withUniformRandomProvider(rng2);
RandomAssert.assertProduceSameSequence(sampler1, sampler2);
}
+
+    /**
+     * Test Gaussian samples using a large number of bins based on uniformly spaced quantiles.
+     * Every bin has the same expected probability (1 / bins). The full distribution is tested
+     * with a chi-square test, then the same test is repeated on sub-ranges of bins around
+     * the mean.
+     * Added for RNG-159.
+     */
+    @Ignore("See RNG-159")
+    @Test
+    public void testGaussianSamplesWithQuantiles() {
+        final int bins = 2000;
+        final double alpha = 0.001;
+        final NormalDistribution dist = new NormalDistribution(null, 0.0, 1.0);
+        // Upper bound of each bin: the quantile for cumulative probability (i + 1) / bins.
+        final double[] quantiles = new double[bins];
+        for (int i = 0; i < bins; i++) {
+            quantiles[i] = dist.inverseCumulativeProbability((i + 1.0) / bins);
+        }
+        // The final bin must cover the entire upper tail so every sample has a bin
+        quantiles[bins - 1] = Double.POSITIVE_INFINITY;
+
+        final int samples = 10000000;
+        final long[] observed = new long[bins];
+        final RestorableUniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(0xabcdefL);
+        final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
+        for (int i = 0; i < samples; i++) {
+            final double x = sampler.sample();
+            observed[findIndex(quantiles, x)]++;
+        }
+        final double[] expected = new double[bins];
+        Arrays.fill(expected, 1.0 / bins);
+
+        final ChiSquareTest chiSquareTest = new ChiSquareTest();
+        // Pass if we cannot reject null hypothesis that the distributions are the same.
+        // Note: chiSquareTest returns the p-value (not the chi-square statistic).
+        double pValue = chiSquareTest.chiSquareTest(expected, observed);
+        Assert.assertFalse("Chi-square p-value = " + pValue, pValue < alpha);
+
+        // Test around the mean
+        for (final double range : new double[] {0.5, 0.25, 0.1, 0.05}) {
+            final int min = findIndex(quantiles, -range);
+            final int max = findIndex(quantiles, range);
+            final long[] observed2 = Arrays.copyOfRange(observed, min, max + 1);
+            final double[] expected2 = Arrays.copyOfRange(expected, min, max + 1);
+            // The expected probabilities are rescaled to the observed total by the chi-square test
+            pValue = chiSquareTest.chiSquareTest(expected2, observed2);
+            // The selected bins span [quantiles[min - 1], quantiles[max]) which encloses [-range, range]
+            final double lower = min == 0 ? Double.NEGATIVE_INFINITY : quantiles[min - 1];
+            final double upper = quantiles[max];
+            Assert.assertFalse(String.format("Range %s: (%s <= x < %s) Chi-square p-value = %s",
+                range, lower, upper, pValue), pValue < alpha);
+        }
+    }
+
+    /**
+     * Test Gaussian samples using a large number of bins uniformly spaced in a range.
+     * The bins are uniformly spaced in [-8, 8]; the first bin also covers the lower tail
+     * and the last bin covers the upper tail, so every sample is counted. The full
+     * distribution is tested with a chi-square test, then the same test is repeated on
+     * sub-ranges of bins around the mean.
+     * Added for RNG-159.
+     */
+    @Ignore("See RNG-159")
+    @Test
+    public void testGaussianSamplesWithUniformValues() {
+        final int bins = 2000;
+        final double alpha = 0.001;
+        // Upper bound of each bin
+        final double[] values = new double[bins];
+        final double minx = -8;
+        final double maxx = 8;
+        for (int i = 0; i < bins; i++) {
+            values[i] = minx + (maxx - minx) * (i + 1.0) / bins;
+        }
+        // The final bin must cover the entire upper tail. A finite bound here means a
+        // sample >= maxx has no bin, and the expected probabilities would omit the upper tail.
+        values[bins - 1] = Double.POSITIVE_INFINITY;
+
+        final int samples = 10000000;
+        final long[] observed = new long[bins];
+        final RestorableUniformRandomProvider rng = RandomSource.XO_SHI_RO_128_PP.create(0xabcdefL);
+        final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
+        for (int i = 0; i < samples; i++) {
+            final double x = sampler.sample();
+            observed[findIndex(values, x)]++;
+        }
+
+        // Compute expected: probability of each bin [previous bound, bound)
+        final NormalDistribution dist = new NormalDistribution(null, 0.0, 1.0);
+        final double[] expected = new double[bins];
+        double x0 = Double.NEGATIVE_INFINITY;
+        for (int i = 0; i < bins; i++) {
+            final double x1 = values[i];
+            expected[i] = dist.probability(x0, x1);
+            x0 = x1;
+        }
+
+        final ChiSquareTest chiSquareTest = new ChiSquareTest();
+        // Pass if we cannot reject null hypothesis that the distributions are the same.
+        // Note: chiSquareTest returns the p-value (not the chi-square statistic).
+        double pValue = chiSquareTest.chiSquareTest(expected, observed);
+        Assert.assertFalse("Chi-square p-value = " + pValue, pValue < alpha);
+
+        // Test around the mean
+        for (final double range : new double[] {0.5, 0.25, 0.1, 0.05}) {
+            final int min = findIndex(values, -range);
+            final int max = findIndex(values, range);
+            final long[] observed2 = Arrays.copyOfRange(observed, min, max + 1);
+            final double[] expected2 = Arrays.copyOfRange(expected, min, max + 1);
+            // The expected probabilities are rescaled to the observed total by the chi-square test
+            pValue = chiSquareTest.chiSquareTest(expected2, observed2);
+            // The selected bins span [values[min - 1], values[max]) which encloses [-range, range]
+            final double lower = min == 0 ? Double.NEGATIVE_INFINITY : values[min - 1];
+            final double upper = values[max];
+            Assert.assertFalse(String.format("Range %s: (%s <= x < %s) Chi-square p-value = %s",
+                range, lower, upper, pValue), pValue < alpha);
+        }
+    }
+
+    /**
+     * Find the index of the value in the data such that:
+     * <pre>
+     * data[index - 1] <= x < data[index]
+     * </pre>
+     *
+     * <p>The data are the upper bounds of a set of bins and must be sorted in ascending order.
+     * Index 0 is returned for any value below {@code data[0]}. The value must be strictly
+     * below the last element of the data; otherwise (or if the value is NaN) the method
+     * fails with an assertion error.
+     *
+     * @param data the data (sorted ascending)
+     * @param x the value
+     * @return the index
+     */
+    private static int findIndex(double[] data, double x) {
+        int low = 0;
+        int high = data.length - 1;
+
+        // Binary search: on exit low is the index of the first element strictly above x
+        while (low <= high) {
+            final int mid = (low + high) >>> 1;
+            if (x < data[mid]) {
+                // The answer is at mid or below
+                high = mid - 1;
+            } else {
+                // data[mid] <= x so the answer is above mid
+                low = mid + 1;
+            }
+        }
+        // A value at or above the last bound (or NaN) has no bin; fail with a clear
+        // message rather than an ArrayIndexOutOfBoundsException on data[low].
+        Assert.assertTrue("Value is not below the last bound: " + x, low < data.length);
+        // Verify the index is correct
+        Assert.assertTrue(x < data[low]);
+        if (low != 0) {
+            Assert.assertTrue(x >= data[low - 1]);
+        }
+        return low;
+    }
}