You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2021/08/27 14:48:32 UTC

[commons-rng] 05/06: Add performance test for extracting the sign bit

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit f97188b337fc87fed6265ed908c69950ef9fdf7a
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Fri Aug 27 14:49:58 2021 +0100

    Add performance test for extracting the sign bit
    
    Update tests requiring a fast generator to use an LCG. This is only
    applicable when the upper bits are required.
---
 .../distribution/ZigguratSamplerPerformance.java   | 125 +++++++++++++++++++--
 1 file changed, 114 insertions(+), 11 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
index ce1db12..58d7747 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
@@ -36,6 +36,7 @@ import org.openjdk.jmh.annotations.Setup;
 import org.openjdk.jmh.annotations.State;
 import org.openjdk.jmh.annotations.Warmup;
 import java.util.Arrays;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 
 /**
@@ -182,7 +183,7 @@ public class ZigguratSamplerPerformance {
             // exponential = 252
             // gaussian = 253
             final int limit = 253;
-            // Use a fast generator
+            // Use a fast generator:
             final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
             if ("CastMaskIntCompare".equals(method)) {
                 sampler = () -> {
@@ -239,12 +240,22 @@ public class ZigguratSamplerPerformance {
         /** Instantiates sampler. */
         @Setup
         public void setup() {
-            // Use a fast generator
-            final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
+            // Use a fast generator:
+            // Here we use a simple linear congruential generator
+            // which should have constant speed and a random upper bit.
+            final long[] s = {ThreadLocalRandom.current().nextLong()};
             if ("Mask".equals(method)) {
-                sampler = () -> rng.nextLong() & Long.MAX_VALUE;
+                sampler = () -> {
+                    final long x = s[0];
+                    s[0] = updateLCG(x);
+                    return x & Long.MAX_VALUE;
+                };
             } else if ("Shift".equals(method)) {
-                sampler = () -> rng.nextLong() >>> 1;
+                sampler = () -> {
+                    final long x = s[0];
+                    s[0] = updateLCG(x);
+                    return x >>> 1;
+                };
             } else {
                 throwIllegalStateException(method);
             }
@@ -256,7 +267,7 @@ public class ZigguratSamplerPerformance {
      */
     @State(Scope.Benchmark)
     public static class InterpolationSources {
-        /** The method to obtain the long. */
+        /** The method to perform interpolation. */
         @Param({"U1", "1minusU2", "U_1minusU"})
         private String method;
 
@@ -273,7 +284,7 @@ public class ZigguratSamplerPerformance {
         /** Instantiates sampler. */
         @Setup
         public void setup() {
-            // Use a fast generator
+            // Use a fast generator:
             final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
             // Get an x table. This is length 254.
             // We will sample from this internally to avoid index out-of-bounds issues.
@@ -283,7 +294,7 @@ public class ZigguratSamplerPerformance {
             final int mask = 127;
             if ("U1".equals(method)) {
                 sampler = () -> {
-                    final long u = rng.nextLong() >>> 1;
+                    final long u = rng.nextLong();
                     final int j = 1 + (((int) u) & mask);
                     // double multiply
                     // double add
@@ -297,7 +308,7 @@ public class ZigguratSamplerPerformance {
                 };
             } else if ("1minusU2".equals(method)) {
                 sampler = () -> {
-                    final long u = rng.nextLong() >>> 1;
+                    final long u = rng.nextLong();
                     final int j = 1 + (((int) u) & mask);
                     // Since u is in [0, 2^63) to create (1 - u) using Long.MIN_VALUE
                     // as an unsigned integer of 2^63.
@@ -314,7 +325,7 @@ public class ZigguratSamplerPerformance {
                 };
             } else if ("U_1minusU".equals(method)) {
                 sampler = () -> {
-                    final long u = rng.nextLong() >>> 1;
+                    final long u = rng.nextLong();
                     final int j = 1 + (((int) u) & mask);
                     // Interpolation between bounds a and b using:
                     // a * u + b * (1 - u) == b + u * (a - b)
@@ -335,6 +346,83 @@ public class ZigguratSamplerPerformance {
     }
 
     /**
+     * Defines methods to map one bit of a {@code long} value to a sign of -1.0 or 1.0.
+     */
+    @State(Scope.Benchmark)
+    public static class SignBitSources {
+        /** The method used to convert a bit of the random {@code long} to a sign value. */
+        @Param({"ifNegative", "ifSignBit", "ifBit", "bitSubtract", "signBitSubtract"})
+        private String method;
+
+        /** The sampler. */
+        private ContinuousSampler sampler;
+
+        /**
+         * @return the sampler; each sample is -1.0 or 1.0.
+         */
+        public ContinuousSampler getSampler() {
+            return sampler;
+        }
+
+        /** Instantiates the sampler for the configured {@code method}. */
+        @Setup
+        public void setup() {
+            // Use a fast generator: the bit-to-sign conversion is the code under test.
+            final UniformRandomProvider rng = RandomSource.XO_RO_SHI_RO_128_PP.create();
+
+            if ("ifNegative".equals(method)) {
+                sampler = () -> {
+                    final long x = rng.nextLong();
+                    return x < 0 ? -1.0 : 1.0; // Branch on the sign bit (set -> -1)
+                };
+            } else if ("ifSignBit".equals(method)) {
+                sampler = () -> {
+                    final long x = rng.nextLong();
+                    return (x >>> 63) == 0 ? -1.0 : 1.0; // Sign bit clear -> -1 (inverse of ifNegative)
+                };
+            } else if ("ifBit".equals(method)) {
+                sampler = () -> {
+                    final long x = rng.nextLong();
+                    return (x & 0x100) == 0 ? -1.0 : 1.0; // Branch on bit 8 (clear -> -1)
+                };
+            } else if ("bitSubtract".equals(method)) {
+                sampler = () -> {
+                    final long x = rng.nextLong();
+                    return ((x >>> 7) & 0x2) - 1.0; // Bit 8 moved to bit 1: (0 or 2) - 1
+                };
+            } else if ("signBitSubtract".equals(method)) {
+                sampler = () -> {
+                    final long x = rng.nextLong();
+                    return ((x >>> 62) & 0x2) - 1.0; // Bit 63 moved to bit 1: (0 or 2) - 1
+                };
+            } else {
+                throwIllegalStateException(method); // Unknown method name
+            }
+        }
+    }
+
+    /**
+     * Updates the state of the linear congruential generator (arithmetic modulo 2^64).
+     * <pre>
+     *  s = m*s + a
+     * </pre>
+     *
+     * <p>This can be used when the upper bits of the long are important.
+     * The lower bits will not be very random. Bit k (k = 0 is the least significant)
+     * has a period of 2^(k+1); the least significant bit simply alternates.
+     *
+     * @param state the current state
+     * @return the next state
+     */
+    private static long updateLCG(long state) {
+        // m is the multiplier used for the LCG component of the JDK 17 generators.
+        // a can be any odd number (an odd increment is needed for the full 2^64 period).
+        // Use the golden ratio from a SplitMix64 generator.
+        return 0xd1342543de82ef95L * state + 0x9e3779b97f4a7c15L;
+
+    }
+
+    /**
      * The samplers to use for testing the ziggurat method.
      */
     @State(Scope.Benchmark)
@@ -3827,7 +3915,8 @@ public class ZigguratSamplerPerformance {
     /**
      * Benchmark methods for obtaining an unsigned long.
      *
-     * <p>Note: This is disabled as there is no measurable difference between methods.
+     * <p>Note: This is disabled. Either there is no measurable difference between methods
+     * or the bit shift method is marginally faster depending on JDK and platform.
      *
      * @param sources Source of randomness.
      * @return the sample value
@@ -3851,6 +3940,20 @@ public class ZigguratSamplerPerformance {
     }
 
     /**
+     * Benchmark methods for obtaining a sign value (-1.0 or 1.0) from a long.
+     *
+     * <p>Note: This is disabled. The branchless versions that compute the sign
+     * as {@code 2 - 1} or {@code 0 - 1} are faster.
+     *
+     * @param sources Source of randomness.
+     * @return the sign value
+     */
+    //@Benchmark
+    public double signBit(SignBitSources sources) {
+        return sources.getSampler().sample();
+    }
+
+    /**
      * Run the sampler.
      *
      * @param sources Source of randomness.