You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2021/09/15 23:43:58 UTC
[commons-rng] 02/02: Added ternary variations to the ziggurat
benchmark
This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 55e02d4387eb9eae48f93e717b7bd65baafc63fd
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Thu Sep 16 00:43:53 2021 +0100
Added ternary variations to the ziggurat benchmark
---
.../distribution/ZigguratSamplerPerformance.java | 248 ++++++++++++++++++++-
.../sampling/distribution/ZigguratSamplerTest.java | 2 +
2 files changed, 248 insertions(+), 2 deletions(-)
diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
index e136cf2..3a52182 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
@@ -19,6 +19,7 @@ package org.apache.commons.rng.examples.jmh.sampling.distribution;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.ObjectSampler;
import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
import org.apache.commons.rng.sampling.distribution.LongSampler;
@@ -88,6 +89,8 @@ public class ZigguratSamplerPerformance {
static final String MOD_GAUSSIAN_E_MAX_TABLE = "ModGaussianEMaxTable";
/** The name for the {@link ModifiedZigguratNormalizedGaussianSamplerEMax2}. */
static final String MOD_GAUSSIAN_E_MAX_2 = "ModGaussianEMax2";
+ /** The name for the {@link ModifiedZigguratNormalizedGaussianSamplerTernary}. */
+ static final String MOD_GAUSSIAN_TERNARY = "ModGaussianTernary";
/** The name for the {@link ModifiedZigguratNormalizedGaussianSampler512} using a table size of 512. */
static final String MOD_GAUSSIAN_512 = "ModGaussian512";
@@ -110,6 +113,8 @@ public class ZigguratSamplerPerformance {
static final String MOD_EXPONENTIAL_E_MAX_TABLE = "ModExponentialEmaxTable";
/** The name for the {@link ModifiedZigguratExponentialSamplerEMax2}. */
static final String MOD_EXPONENTIAL_E_MAX_2 = "ModExponentialEmax2";
+ /** The name for the {@link ModifiedZigguratExponentialSamplerTernary}. */
+ static final String MOD_EXPONENTIAL_TERNARY = "ModExponentialTernary";
/** The name for the {@link ModifiedZigguratExponentialSampler512} using a table size of 512. */
static final String MOD_EXPONENTIAL_512 = "ModExponential512";
@@ -464,6 +469,74 @@ public class ZigguratSamplerPerformance {
}
/**
+ * Defines the method used to create two random {@code long} values, optionally placed in ascending order.
+ */
+ @State(Scope.Benchmark)
+ public static class DiffSources {
+ /** The method used to obtain the two long values. */
+ @Param({"None", "Branch", "Ternary", "TernarySubtract"})
+ private String method;
+
+ /** The sampler. Every sample returns the same two-element array, overwritten on each call. */
+ private ObjectSampler<long[]> sampler;
+
+ /**
+ * @return the sampler.
+ */
+ public ObjectSampler<long[]> getSampler() {
+ return sampler;
+ }
+
+ /** Instantiates the sampler selected by the {@code method} parameter. */
+ @Setup
+ public void setup() {
+ // Use a fast generator; each value is shifted right by one bit so it is non-negative:
+ final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
+ final long[] tmp = new long[2];
+ if ("None".equals(method)) {
+ sampler = () -> {
+ tmp[0] = rng.nextLong() >>> 1;
+ tmp[1] = rng.nextLong() >>> 1;
+ return tmp;
+ };
+ } else if ("Branch".equals(method)) {
+ sampler = () -> {
+ long u1 = rng.nextLong() >>> 1;
+ long uDistance = (rng.nextLong() >>> 1) - u1;
+ if (uDistance < 0) {
+ // Upper-right triangle. Reflect in hypotenuse.
+ uDistance = -uDistance;
+ // Update u1 to be min(u1, u2) by subtracting the distance between them
+ u1 -= uDistance;
+ }
+ tmp[0] = u1;
+ tmp[1] = uDistance;
+ return tmp;
+ };
+ } else if ("Ternary".equals(method)) {
+ sampler = () -> {
+ long u1 = rng.nextLong() >>> 1;
+ long u2 = rng.nextLong() >>> 1;
+ tmp[0] = u1 < u2 ? u1 : u2;
+ tmp[1] = u1 < u2 ? u2 : u1;
+ return tmp;
+ };
+ } else if ("TernarySubtract".equals(method)) {
+ sampler = () -> {
+ long u1 = rng.nextLong() >>> 1;
+ long u2 = rng.nextLong() >>> 1;
+ long u = u1 < u2 ? u1 : u2;
+ tmp[0] = u;
+ tmp[1] = u1 + u2 - u;
+ return tmp;
+ };
+ } else {
+ throwIllegalStateException(method);
+ }
+ }
+ }
+
+ /**
* Update the state of the linear congruential generator.
* <pre>
* s = m*s + a
@@ -500,12 +573,14 @@ public class ZigguratSamplerPerformance {
MOD_GAUSSIAN2, MOD_GAUSSIAN_SIMPLE_OVERHANGS,
MOD_GAUSSIAN_INLINING, MOD_GAUSSIAN_INLINING_SHIFT,
MOD_GAUSSIAN_INLINING_SIMPLE_OVERHANGS, MOD_GAUSSIAN_INT_MAP,
- MOD_GAUSSIAN_E_MAX_TABLE, MOD_GAUSSIAN_E_MAX_2, MOD_GAUSSIAN_512,
+ MOD_GAUSSIAN_E_MAX_TABLE, MOD_GAUSSIAN_E_MAX_2,
+ MOD_GAUSSIAN_TERNARY, MOD_GAUSSIAN_512,
// Experimental McFarland exponential ziggurat samplers
MOD_EXPONENTIAL2, MOD_EXPONENTIAL_SIMPLE_OVERHANGS, MOD_EXPONENTIAL_INLINING,
MOD_EXPONENTIAL_LOOP, MOD_EXPONENTIAL_LOOP2,
MOD_EXPONENTIAL_RECURSION, MOD_EXPONENTIAL_INT_MAP,
- MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2, MOD_EXPONENTIAL_512})
+ MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2,
+ MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_512})
protected String type;
/**
@@ -542,6 +617,8 @@ public class ZigguratSamplerPerformance {
return new ModifiedZigguratNormalizedGaussianSamplerEMaxTable(rng);
} else if (MOD_GAUSSIAN_E_MAX_2.equals(type)) {
return new ModifiedZigguratNormalizedGaussianSamplerEMax2(rng);
+ } else if (MOD_GAUSSIAN_TERNARY.equals(type)) {
+ return new ModifiedZigguratNormalizedGaussianSamplerTernary(rng);
} else if (MOD_GAUSSIAN_512.equals(type)) {
return new ModifiedZigguratNormalizedGaussianSampler512(rng);
} else if (MOD_EXPONENTIAL2.equals(type)) {
@@ -562,6 +639,8 @@ public class ZigguratSamplerPerformance {
return new ModifiedZigguratExponentialSamplerEMaxTable(rng);
} else if (MOD_EXPONENTIAL_E_MAX_2.equals(type)) {
return new ModifiedZigguratExponentialSamplerEMax2(rng);
+ } else if (MOD_EXPONENTIAL_TERNARY.equals(type)) {
+ return new ModifiedZigguratExponentialSamplerTernary(rng);
} else if (MOD_EXPONENTIAL_512.equals(type)) {
return new ModifiedZigguratExponentialSampler512(rng);
} else {
@@ -2331,6 +2410,108 @@ public class ZigguratSamplerPerformance {
* <p>Uses the algorithm from McFarland, C.D. (2016).
*
* <p>This is a copy of {@link ModifiedZigguratNormalizedGaussianSampler} using
+ * a ternary operator to sort the two random long values.
+ */
+ static class ModifiedZigguratNormalizedGaussianSamplerTernary
+ extends ModifiedZigguratNormalizedGaussianSampler {
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ */
+ ModifiedZigguratNormalizedGaussianSamplerTernary(UniformRandomProvider rng) {
+ super(rng);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public double sample() {
+ final long xx = nextLong();
+ // Float multiplication squashes these last 8 bits, so they can be used to sample i
+ final int i = ((int) xx) & 0xff;
+
+ if (i < I_MAX) {
+ // Early exit.
+ return X[i] * xx;
+ }
+
+ // Recycle bits then advance RNG:
+ long u1 = xx & MAX_INT64;
+ // Another squashed, recyclable bit
+ // double sign_bit = u1 & 0x100 ? 1. : -1.
+ // Bit 8 (0x100) is shifted down to bit 1 giving 2 or 0: use 2 - 1 or 0 - 1
+ final double signBit = ((u1 >>> 7) & 0x2) - 1.0;
+ final int j = selectRegion();
+ // Four kinds of overhangs:
+ // j = 0 : Sample from tail
+ // 0 < j < J_INFLECTION : Overhang is concave; only sample from Lower-Left triangle
+ // j = J_INFLECTION : Must sample from entire overhang rectangle
+ // j > J_INFLECTION : Overhangs are convex; implicitly accept point in Lower-Left triangle
+ //
+ // Conditional statements are arranged such that the more likely outcomes are first.
+ double x;
+ if (j > J_INFLECTION) {
+ // Convex overhang
+ for (;;) {
+ x = sampleX(X, j, u1);
+ final long uDistance = randomInt63() - u1;
+ if (uDistance >= 0) {
+ // Lower-left triangle
+ break;
+ }
+ if (uDistance >= CONVEX_E_MAX &&
+ // Within maximum distance of f(x) from the triangle hypotenuse.
+ sampleY(Y, j, u1 + uDistance) < Math.exp(-0.5 * x * x)) {
+ break;
+ }
+ // uDistance < CONVEX_E_MAX (upper-right triangle) or rejected as above the curve
+ u1 = randomInt63();
+ }
+ } else if (j < J_INFLECTION) {
+ if (j == 0) {
+ // Tail
+ // Note: Although less frequent than the next branch, j == 0 is a subset of
+ // j < J_INFLECTION and must be first.
+ do {
+ x = ONE_OVER_X_0 * exponential.sample();
+ } while (exponential.sample() < 0.5 * x * x);
+ x += X_0;
+ } else {
+ // Concave overhang
+ for (;;) {
+ // If ub < ua then reflect in the hypotenuse by swapping the two values.
+ final long ua = u1;
+ final long ub = randomInt63();
+ // Sort so that u1 <= u2 to sample the lower-left triangle
+ u1 = ua < ub ? ua : ub;
+ final long u2 = ua < ub ? ub : ua;
+ x = sampleX(X, j, u1);
+ if (u2 - u1 > CONCAVE_E_MAX ||
+ sampleY(Y, j, u2) < Math.exp(-0.5 * x * x)) {
+ break;
+ }
+ u1 = randomInt63();
+ }
+ }
+ } else {
+ // Inflection point
+ for (;;) {
+ x = sampleX(X, j, u1);
+ if (sampleY(Y, j, randomInt63()) < Math.exp(-0.5 * x * x)) {
+ break;
+ }
+ u1 = randomInt63();
+ }
+ }
+ return signBit * x;
+ }
+ }
+
+ /**
+ * Modified Ziggurat method for sampling from a Gaussian distribution with mean 0 and standard deviation 1.
+ *
+ * <p>Uses the algorithm from McFarland, C.D. (2016).
+ *
+ * <p>This is a copy of {@link ModifiedZigguratNormalizedGaussianSampler} using
* a table size of 512.
*/
static class ModifiedZigguratNormalizedGaussianSampler512 implements ContinuousSampler {
@@ -3940,6 +4121,53 @@ public class ZigguratSamplerPerformance {
* <p>Uses the algorithm from McFarland, C.D. (2016).
*
* <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
+ * a ternary operator to sort the two random long values.
+ */
+ static class ModifiedZigguratExponentialSamplerTernary
+ extends ModifiedZigguratExponentialSampler {
+
+ /**
+ * @param rng Generator of uniformly distributed random numbers.
+ */
+ ModifiedZigguratExponentialSamplerTernary(UniformRandomProvider rng) {
+ super(rng);
+ }
+
+ @Override
+ protected double sampleOverhang(int j) {
+ // Sample from the triangle:
+ // X[j],Y[j]
+ // |\-->u1
+ // | \ |
+ // | \ |
+ // | \| Overhang j (with hypotenuse not pdf(x))
+ // | \
+ // | |\
+ // | | \
+ // | u2 \
+ // +-------- X[j-1],Y[j-1]
+ // If ub < ua then reflect in the hypotenuse by swapping the two values.
+ final long ua = randomInt63();
+ final long ub = randomInt63();
+ // Sort so that u1 <= u2 to sample the lower-left triangle
+ final long u1 = ua < ub ? ua : ub;
+ final long u2 = ua < ub ? ub : ua;
+ final double x = sampleX(X, j, u1);
+ if (u2 - u1 >= E_MAX) {
+ // Early Exit: x < y - epsilon
+ return x;
+ }
+
+ return sampleY(Y, j, u2) <= Math.exp(-x) ? x : sampleOverhang(j);
+ }
+ }
+
+ /**
+ * Modified Ziggurat method for sampling from an exponential distribution.
+ *
+ * <p>Uses the algorithm from McFarland, C.D. (2016).
+ *
+ * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
* a table size of 512.
*/
static class ModifiedZigguratExponentialSampler512 implements ContinuousSampler {
@@ -4652,6 +4880,22 @@ public class ZigguratSamplerPerformance {
}
/**
+ * Benchmark methods for obtaining 2 random longs in ascending order.
+ *
+ * <p>Note: This is disabled. The branchless versions using a ternary
+ * conditional assignment are faster. This may not manifest as a performance
+ * improvement when used in the ziggurat sampler as it is not on the
+ * hot path (i.e. sampling inside the ziggurat).
+ *
+ * @param sources Provides the sampler for the selected method.
+ * @return the array of two values produced by the sampler
+ */
+ //@Benchmark
+ public long[] diff(DiffSources sources) {
+ return sources.getSampler().sample();
+ }
+
+ /**
* Benchmark methods for obtaining {@code exp(z)} when {@code -8 <= z <= 0}.
*
* <p>Note: This is disabled. On JDK 8 FastMath is faster. On JDK 11 Math.exp is
diff --git a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
index a58f325..25d0670 100644
--- a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
+++ b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
@@ -91,6 +91,7 @@ class ZigguratSamplerTest {
args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INT_MAP),
args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_TABLE),
args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_2),
+ args(ZigguratSamplerPerformance.MOD_GAUSSIAN_TERNARY),
args(ZigguratSamplerPerformance.MOD_GAUSSIAN_512));
}
@@ -115,6 +116,7 @@ class ZigguratSamplerTest {
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INT_MAP),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
+ args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
}