You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2021/09/15 23:43:58 UTC

[commons-rng] 02/02: Added ternary variations to the ziggurat benchmark

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 55e02d4387eb9eae48f93e717b7bd65baafc63fd
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Thu Sep 16 00:43:53 2021 +0100

    Added ternary variations to the ziggurat benchmark
---
 .../distribution/ZigguratSamplerPerformance.java   | 248 ++++++++++++++++++++-
 .../sampling/distribution/ZigguratSamplerTest.java |   2 +
 2 files changed, 248 insertions(+), 2 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
index e136cf2..3a52182 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
@@ -19,6 +19,7 @@ package org.apache.commons.rng.examples.jmh.sampling.distribution;
 
 import org.apache.commons.math3.util.FastMath;
 import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.ObjectSampler;
 import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
 import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
 import org.apache.commons.rng.sampling.distribution.LongSampler;
@@ -88,6 +89,8 @@ public class ZigguratSamplerPerformance {
     static final String MOD_GAUSSIAN_E_MAX_TABLE = "ModGaussianEMaxTable";
     /** The name for the {@link ModifiedZigguratNormalizedGaussianSamplerEMax2}. */
     static final String MOD_GAUSSIAN_E_MAX_2 = "ModGaussianEMax2";
+    /** The name for the {@link ModifiedZigguratNormalizedGaussianSamplerTernary}. */
+    static final String MOD_GAUSSIAN_TERNARY = "ModGaussianTernary";
     /** The name for the {@link ModifiedZigguratNormalizedGaussianSampler512} using a table size of 512. */
     static final String MOD_GAUSSIAN_512 = "ModGaussian512";
 
@@ -110,6 +113,8 @@ public class ZigguratSamplerPerformance {
     static final String MOD_EXPONENTIAL_E_MAX_TABLE = "ModExponentialEmaxTable";
     /** The name for the {@link ModifiedZigguratExponentialSamplerEMax2}. */
     static final String MOD_EXPONENTIAL_E_MAX_2 = "ModExponentialEmax2";
+    /** The name for the {@link ModifiedZigguratExponentialSamplerTernary}. */
+    static final String MOD_EXPONENTIAL_TERNARY = "ModExponentialTernary";
     /** The name for the {@link ModifiedZigguratExponentialSampler512} using a table size of 512. */
     static final String MOD_EXPONENTIAL_512 = "ModExponential512";
 
@@ -464,6 +469,74 @@ public class ZigguratSamplerPerformance {
     }
 
     /**
+     * Defines the method used to create a pair of non-negative random {@code long} values (sorted by most methods).
+     */
+    @State(Scope.Benchmark)
+    public static class DiffSources {
+        /** The method for the two longs: "None" is unsorted; "Branch" is (min, distance); others are (min, max). */
+        @Param({"None", "Branch", "Ternary", "TernarySubtract"})
+        private String method;
+
+        /** The sampler. Each sample refills and returns the same array of length 2. */
+        private ObjectSampler<long[]> sampler;
+
+        /**
+         * @return the sampler created for the configured method.
+         */
+        public ObjectSampler<long[]> getSampler() {
+            return sampler;
+        }
+
+        /** Instantiates the sampler for the configured method. */
+        @Setup
+        public void setup() {
+            // Use a fast generator; the output array is shared by all samples:
+            final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
+            final long[] tmp = new long[2];
+            if ("None".equals(method)) {
+                sampler = () -> {
+                    tmp[0] = rng.nextLong() >>> 1;
+                    tmp[1] = rng.nextLong() >>> 1;
+                    return tmp;
+                };
+            } else if ("Branch".equals(method)) {
+                sampler = () -> {
+                    long u1 = rng.nextLong() >>> 1;
+                    long uDistance = (rng.nextLong() >>> 1) - u1;
+                    if (uDistance < 0) {
+                        // Upper-right triangle. Reflect in hypotenuse (make the distance positive).
+                        uDistance = -uDistance;
+                        // Update u1 to be min(u1, u2) by subtracting the distance between them
+                        u1 -= uDistance;
+                    }
+                    tmp[0] = u1;
+                    tmp[1] = uDistance;
+                    return tmp;
+                };
+            } else if ("Ternary".equals(method)) {
+                sampler = () -> {
+                    long u1 = rng.nextLong() >>> 1;
+                    long u2 = rng.nextLong() >>> 1;
+                    tmp[0] = u1 < u2 ? u1 : u2;
+                    tmp[1] = u1 < u2 ? u2 : u1;
+                    return tmp;
+                };
+            } else if ("TernarySubtract".equals(method)) {
+                sampler = () -> {
+                    long u1 = rng.nextLong() >>> 1;
+                    long u2 = rng.nextLong() >>> 1;
+                    long u = u1 < u2 ? u1 : u2;
+                    tmp[0] = u;
+                    tmp[1] = u1 + u2 - u;
+                    return tmp;
+                };
+            } else {
+                throwIllegalStateException(method);
+            }
+        }
+    }
+
+    /**
      * Update the state of the linear congruential generator.
      * <pre>
      *  s = m*s + a
@@ -500,12 +573,14 @@ public class ZigguratSamplerPerformance {
                 MOD_GAUSSIAN2, MOD_GAUSSIAN_SIMPLE_OVERHANGS,
                 MOD_GAUSSIAN_INLINING, MOD_GAUSSIAN_INLINING_SHIFT,
                 MOD_GAUSSIAN_INLINING_SIMPLE_OVERHANGS, MOD_GAUSSIAN_INT_MAP,
-                MOD_GAUSSIAN_E_MAX_TABLE, MOD_GAUSSIAN_E_MAX_2, MOD_GAUSSIAN_512,
+                MOD_GAUSSIAN_E_MAX_TABLE, MOD_GAUSSIAN_E_MAX_2,
+                MOD_GAUSSIAN_TERNARY, MOD_GAUSSIAN_512,
                 // Experimental McFarland exponential ziggurat samplers
                 MOD_EXPONENTIAL2, MOD_EXPONENTIAL_SIMPLE_OVERHANGS, MOD_EXPONENTIAL_INLINING,
                 MOD_EXPONENTIAL_LOOP, MOD_EXPONENTIAL_LOOP2,
                 MOD_EXPONENTIAL_RECURSION, MOD_EXPONENTIAL_INT_MAP,
-                MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2, MOD_EXPONENTIAL_512})
+                MOD_EXPONENTIAL_E_MAX_TABLE, MOD_EXPONENTIAL_E_MAX_2,
+                MOD_EXPONENTIAL_TERNARY, MOD_EXPONENTIAL_512})
         protected String type;
 
         /**
@@ -542,6 +617,8 @@ public class ZigguratSamplerPerformance {
                 return new ModifiedZigguratNormalizedGaussianSamplerEMaxTable(rng);
             } else if (MOD_GAUSSIAN_E_MAX_2.equals(type)) {
                 return new ModifiedZigguratNormalizedGaussianSamplerEMax2(rng);
+            } else if (MOD_GAUSSIAN_TERNARY.equals(type)) {
+                return new ModifiedZigguratNormalizedGaussianSamplerTernary(rng);
             } else if (MOD_GAUSSIAN_512.equals(type)) {
                 return new ModifiedZigguratNormalizedGaussianSampler512(rng);
             } else if (MOD_EXPONENTIAL2.equals(type)) {
@@ -562,6 +639,8 @@ public class ZigguratSamplerPerformance {
                 return new ModifiedZigguratExponentialSamplerEMaxTable(rng);
             } else if (MOD_EXPONENTIAL_E_MAX_2.equals(type)) {
                 return new ModifiedZigguratExponentialSamplerEMax2(rng);
+            } else if (MOD_EXPONENTIAL_TERNARY.equals(type)) {
+                return new ModifiedZigguratExponentialSamplerTernary(rng);
             } else if (MOD_EXPONENTIAL_512.equals(type)) {
                 return new ModifiedZigguratExponentialSampler512(rng);
             } else {
@@ -2331,6 +2410,108 @@ public class ZigguratSamplerPerformance {
      * <p>Uses the algorithm from McFarland, C.D. (2016).
      *
      * <p>This is a copy of {@link ModifiedZigguratNormalizedGaussianSampler} using
+     * a ternary operator to sort the two random long values.
+     */
+    static class ModifiedZigguratNormalizedGaussianSamplerTernary
+        extends ModifiedZigguratNormalizedGaussianSampler {
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         */
+        ModifiedZigguratNormalizedGaussianSamplerTernary(UniformRandomProvider rng) {
+            super(rng);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public double sample() {
+            final long xx = nextLong();
+            // Float multiplication squashes these last 8 bits, so they can be used to sample i
+            final int i = ((int) xx) & 0xff;
+
+            if (i < I_MAX) {
+                // Early exit.
+                return X[i] * xx;
+            }
+
+            // Recycle bits then advance RNG:
+            long u1 = xx & MAX_INT64;
+            // Another squashed, recyclable bit
+            // double sign_bit = u1 & 0x100 ? 1. : -1.
+            // Use 2 - 1 or 0 - 1
+            final double signBit = ((u1 >>> 7) & 0x2) - 1.0;
+            final int j = selectRegion();
+            // Four kinds of overhangs:
+            //  j = 0                :  Sample from tail
+            //  0 < j < J_INFLECTION :  Overhang is concave; only sample from Lower-Left triangle
+            //  j = J_INFLECTION     :  Must sample from entire overhang rectangle
+            //  j > J_INFLECTION     :  Overhangs are convex; implicitly accept point in Lower-Left triangle
+            //
+            // Conditional statements are arranged such that the more likely outcomes are first.
+            double x;
+            if (j > J_INFLECTION) {
+                // Convex overhang
+                for (;;) {
+                    x = sampleX(X, j, u1);
+                    final long uDistance = randomInt63() - u1;
+                    if (uDistance >= 0) {
+                        // Lower-left triangle
+                        break;
+                    }
+                    if (uDistance >= CONVEX_E_MAX &&
+                        // Within maximum distance of f(x) from the triangle hypotenuse.
+                        sampleY(Y, j, u1 + uDistance) < Math.exp(-0.5 * x * x)) {
+                        break;
+                    }
+                    // uDistance < CONVEX_E_MAX (upper-right triangle) or rejected as above the curve
+                    u1 = randomInt63();
+                }
+            } else if (j < J_INFLECTION) {
+                if (j == 0) {
+                    // Tail
+                    // Note: Although less frequent than the next branch, j == 0 is a subset of
+                    // j < J_INFLECTION and must be first.
+                    do {
+                        x = ONE_OVER_X_0 * exponential.sample();
+                    } while (exponential.sample() < 0.5 * x * x);
+                    x += X_0;
+                } else {
+                    // Concave overhang
+                    for (;;) {
+                        // If ub < ua then reflect in the hypotenuse by swapping the two values.
+                        final long ua = u1;
+                        final long ub = randomInt63();
+                        // Sort so that u1 <= u2 to sample the lower-left triangle
+                        u1 = ua < ub ? ua : ub;
+                        final long u2 = ua < ub ? ub : ua;
+                        x = sampleX(X, j, u1);
+                        if (u2 - u1 > CONCAVE_E_MAX ||
+                            sampleY(Y, j, u2) < Math.exp(-0.5 * x * x)) {
+                            break;
+                        }
+                        u1 = randomInt63();
+                    }
+                }
+            } else {
+                // Inflection point
+                for (;;) {
+                    x = sampleX(X, j, u1);
+                    if (sampleY(Y, j, randomInt63()) < Math.exp(-0.5 * x * x)) {
+                        break;
+                    }
+                    u1 = randomInt63();
+                }
+            }
+            return signBit * x;
+        }
+    }
+
+    /**
+     * Modified Ziggurat method for sampling from a Gaussian distribution with mean 0 and standard deviation 1.
+     *
+     * <p>Uses the algorithm from McFarland, C.D. (2016).
+     *
+     * <p>This is a copy of {@link ModifiedZigguratNormalizedGaussianSampler} using
      * a table size of 512.
      */
     static class ModifiedZigguratNormalizedGaussianSampler512 implements ContinuousSampler {
@@ -3940,6 +4121,53 @@ public class ZigguratSamplerPerformance {
      * <p>Uses the algorithm from McFarland, C.D. (2016).
      *
      * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
+     * a ternary operator to sort the two random long values.
+     */
+    static class ModifiedZigguratExponentialSamplerTernary
+        extends ModifiedZigguratExponentialSampler {
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         */
+        ModifiedZigguratExponentialSamplerTernary(UniformRandomProvider rng) {
+            super(rng);
+        }
+
+        @Override
+        protected double sampleOverhang(int j) {
+            // Sample from the triangle:
+            //    X[j],Y[j]
+            //        |\-->u1
+            //        | \  |
+            //        |  \ |
+            //        |   \|    Overhang j (with hypotenuse not pdf(x))
+            //        |    \
+            //        |    |\
+            //        |    | \
+            //        |    u2 \
+            //        +-------- X[j-1],Y[j-1]
+            // If ub < ua then reflect in the hypotenuse by swapping the two values.
+            final long ua = randomInt63();
+            final long ub = randomInt63();
+            // Sort so that u1 <= u2 to sample the lower-left triangle
+            final long u1 = ua < ub ? ua : ub;
+            final long u2 = ua < ub ? ub : ua;
+            final double x = sampleX(X, j, u1);
+            if (u2 - u1 >= E_MAX) {
+                // Early Exit: x < y - epsilon
+                return x;
+            }
+
+            return sampleY(Y, j, u2) <= Math.exp(-x) ? x : sampleOverhang(j);
+        }
+    }
+
+    /**
+     * Modified Ziggurat method for sampling from an exponential distribution.
+     *
+     * <p>Uses the algorithm from McFarland, C.D. (2016).
+     *
+     * <p>This is a copy of {@link ModifiedZigguratExponentialSampler} using
      * a table size of 512.
      */
     static class ModifiedZigguratExponentialSampler512 implements ContinuousSampler {
@@ -4652,6 +4880,22 @@ public class ZigguratSamplerPerformance {
     }
 
     /**
+     * Benchmark methods for obtaining 2 random longs, optionally sorted into ascending order.
+     *
+     * <p>Note: This is disabled. The branchless versions using a ternary
+     * conditional assignment are faster. This may not manifest as a performance
+     * improvement when used in the ziggurat sampler as it is not on the
+     * hot path (i.e. sampling inside the ziggurat).
+     *
+     * @param sources Source of the sampler for the configured method.
+     * @return the two values produced by a single call to the sampler
+     */
+    //@Benchmark
+    public long[] diff(DiffSources sources) {
+        return sources.getSampler().sample();
+    }
+
+    /**
      * Benchmark methods for obtaining {@code exp(z)} when {@code -8 <= z <= 0}.
      *
      * <p>Note: This is disabled. On JDK 8 FastMath is faster. On JDK 11 Math.exp is
diff --git a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
index a58f325..25d0670 100644
--- a/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
+++ b/commons-rng-examples/examples-jmh/src/test/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerTest.java
@@ -91,6 +91,7 @@ class ZigguratSamplerTest {
             args(ZigguratSamplerPerformance.MOD_GAUSSIAN_INT_MAP),
             args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_TABLE),
             args(ZigguratSamplerPerformance.MOD_GAUSSIAN_E_MAX_2),
+            args(ZigguratSamplerPerformance.MOD_GAUSSIAN_TERNARY),
             args(ZigguratSamplerPerformance.MOD_GAUSSIAN_512));
     }
 
@@ -115,6 +116,7 @@ class ZigguratSamplerTest {
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_INT_MAP),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_TABLE),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_E_MAX_2),
+                args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_TERNARY),
                 args(ZigguratSamplerPerformance.MOD_EXPONENTIAL_512));
     }