You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/08/05 20:03:52 UTC

[commons-rng] 01/02: RNG-90: Use faster algorithm for nextInt(int).

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 057ac5ae8c85dfbac0fcbdf84a43a9a7cea375e1
Author: Alex Herbert <ah...@apache.org>
AuthorDate: Sat Aug 3 13:08:38 2019 +0100

    RNG-90: Use faster algorithm for nextInt(int).
    
    See Lemire (2019): Fast Random Integer Generation in an Interval
    https://arxiv.org/abs/1805.10941
---
 .../org/apache/commons/rng/core/BaseProvider.java  |  47 +-
 .../examples/jmh/RngNextIntInRangeBenchmark.java   | 508 +++++++++++++++++++++
 2 files changed, 531 insertions(+), 24 deletions(-)

diff --git a/commons-rng-core/src/main/java/org/apache/commons/rng/core/BaseProvider.java b/commons-rng-core/src/main/java/org/apache/commons/rng/core/BaseProvider.java
index fa355c7..9313a3f 100644
--- a/commons-rng-core/src/main/java/org/apache/commons/rng/core/BaseProvider.java
+++ b/commons-rng-core/src/main/java/org/apache/commons/rng/core/BaseProvider.java
@@ -25,28 +25,39 @@ import org.apache.commons.rng.RandomProviderState;
  */
 public abstract class BaseProvider
     implements RestorableUniformRandomProvider {
+    /** Error message when an integer is not positive. */
+    private static final String NOT_POSITIVE = "Must be strictly positive: ";
+    /** 2^32. */
+    private static final long POW_32 = 1L << 32;
+
     /** {@inheritDoc} */
     @Override
     public int nextInt(int n) {
-        checkStrictlyPositive(n);
-
-        if ((n & -n) == n) {
-            return (int) ((n * (long) (nextInt() >>> 1)) >> 31);
+        if (n <= 0) {
+            throw new IllegalArgumentException(NOT_POSITIVE + n);
         }
-        int bits;
-        int val;
-        do {
-            bits = nextInt() >>> 1;
-            val = bits % n;
-        } while (bits - val + (n - 1) < 0);
 
-        return val;
+        // Lemire (2019): Fast Random Integer Generation in an Interval
+        // https://arxiv.org/abs/1805.10941
+        long m = (nextInt() & 0xffffffffL) * n;
+        long l = m & 0xffffffffL;
+        if (l < n) {
+            // 2^32 % n
+            final long t = POW_32 % n;
+            while (l < t) {
+                m = (nextInt() & 0xffffffffL) * n;
+                l = m & 0xffffffffL;
+            }
+        }
+        return (int) (m >>> 32);
     }
 
     /** {@inheritDoc} */
     @Override
     public long nextLong(long n) {
-        checkStrictlyPositive(n);
+        if (n <= 0) {
+            throw new IllegalArgumentException(NOT_POSITIVE + n);
+        }
 
         long bits;
         long val;
@@ -278,18 +289,6 @@ public abstract class BaseProvider
     }
 
     /**
-     * Checks that the argument is strictly positive.
-     *
-     * @param n Number to check.
-     * @throws IllegalArgumentException if {@code n <= 0}.
-     */
-    private void checkStrictlyPositive(long n) {
-        if (n <= 0) {
-            throw new IllegalArgumentException("Must be strictly positive: " + n);
-        }
-    }
-
-    /**
      * Transformation used to scramble the initial state of
      * a generator.
      *
diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/RngNextIntInRangeBenchmark.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/RngNextIntInRangeBenchmark.java
new file mode 100644
index 0000000..3edae74
--- /dev/null
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/RngNextIntInRangeBenchmark.java
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.examples.jmh;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.core.source32.IntProvider;
+import org.apache.commons.rng.sampling.PermutationSampler;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Executes benchmark to compare the speed of random number generators to create
+ * an int value in a range.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = { "-server", "-Xms128M", "-Xmx128M" })
+public class RngNextIntInRangeBenchmark {
+    /** The value. Must NOT be final to prevent JVM optimisation! */
+    private int intValue;
+
+    /**
+     * The upper range for the {@code int} generation.
+     */
+    @State(Scope.Benchmark)
+    public static class IntRange {
+        /**
+         * The upper range for the {@code int} generation.
+         *
+         * <p>Note that the while loop uses a rejection algorithm. From the Javadoc for
+         * java.util.Random:</p>
+         *
+         * <pre>
+         * "The probability of a value being rejected depends on n. The
+         * worst case is n=2^30+1, for which the probability of a reject is 1/2,
+         * and the expected number of iterations before the loop terminates is 2."
+         * </pre>
+         */
+        @Param({"16", "17", "256", "257", "4096", "4097",
+                // Worst case power-of-2: (1 << 30)
+                "1073741824",
+                // Worst case: (1 << 30) + 1
+                "1073741825", })
+        private int n;
+
+        /**
+         * Gets the upper bound {@code n}.
+         *
+         * @return the upper bound
+         */
+        public int getN() {
+            return n;
+        }
+    }
+
+    /**
+     * The data used for the shuffle benchmark.
+     */
+    @State(Scope.Benchmark)
+    public static class IntData {
+        /**
+         * The size of the data.
+         */
+        @Param({ "4", "16", "256", "4096", "16384" })
+        private int size;
+
+        /** The data. */
+        private int[] data;
+
+        /**
+         * Gets the data.
+         *
+         * @return the data
+         */
+        public int[] getData() {
+            return data;
+        }
+
+        /**
+         * Create the data.
+         */
+        @Setup
+        public void setup() {
+            data = PermutationSampler.natural(size);
+        }
+    }
+
+    /**
+     * The source generator.
+     */
+    @State(Scope.Benchmark)
+    public static class Source {
+        /**
+         * The name of the generator.
+         */
+        @Param({ "jdk", "jdkPow2", "lemire", "lemirePow2", "lemire31", "lemire31Pow2"})
+        private String name;
+
+        /** The random generator. */
+        private UniformRandomProvider rng;
+
+        /**
+         * Gets the random generator.
+         *
+         * @return the generator
+         */
+        public UniformRandomProvider getRng() {
+            return rng;
+        }
+
+        /** Create the generator. */
+        @Setup
+        public void setup() {
+            final long seed = ThreadLocalRandom.current().nextLong();
+            if ("jdk".equals(name)) {
+                rng = new JDK(seed);
+            } else if ("jdkPow2".equals(name)) {
+                rng = new JDKPow2(seed);
+            } else if ("lemire".equals(name)) {
+                rng = new Lemire(seed);
+            } else if ("lemirePow2".equals(name)) {
+                rng = new LemirePow2(seed);
+            } else if ("lemire31".equals(name)) {
+                rng = new Lemire31(seed);
+            } else if ("lemire31Pow2".equals(name)) {
+                rng = new Lemire31Pow2(seed);
+            }
+        }
+    }
+
+    /**
+     * Implement the SplitMix algorithm from {@link java.util.SplittableRandom
+     * SplittableRandom} to output 32-bit int values.
+     *
+     * <p>This is a base generator to test nextInt(int) methods.
+     */
+    abstract static class SplitMix32 extends IntProvider {
+        /**
+         * The golden ratio, phi, scaled to 64-bits and rounded to odd.
+         */
+        private static final long GOLDEN_RATIO = 0x9e3779b97f4a7c15L;
+
+        /** The state. */
+        protected long state;
+
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        SplitMix32(long seed) {
+            state = seed;
+        }
+
+        @Override
+        public int next() {
+            long key = state += GOLDEN_RATIO;
+            // 32 high bits of Stafford variant 4 mix64 function as int:
+            // http://zimbry.blogspot.com/2011/09/better-bit-mixing-improving-on.html
+            key = (key ^ (key >>> 33)) * 0x62a9d9ed799705f5L;
+            return (int) (((key ^ (key >>> 28)) * 0xcb24d0a5c88c35b3L) >>> 32);
+        }
+
+        /**
+         * Check the value is strictly positive.
+         *
+         * @param n the value
+         */
+        void checkStrictlyPositive(int n) {
+            if (n <= 0) {
+                throw new IllegalArgumentException("not strictly positive: " + n);
+            }
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of the JDK excluding the case for a power-of-2 range.
+     */
+    static class JDK extends SplitMix32 {
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        JDK(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            int bits;
+            int val;
+            do {
+                bits = next() >>> 1;
+                val = bits % n;
+            } while (bits - val + n - 1 < 0);
+
+            return val;
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of the JDK with a case for a power-of-2 range.
+     */
+    static class JDKPow2 extends SplitMix32 {
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        JDKPow2(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            final int nm1 = n - 1;
+            if ((n & nm1) == 0) {
+                // Power of 2
+                return next() & nm1;
+            }
+
+            int bits;
+            int val;
+            do {
+                bits = next() >>> 1;
+                val = bits % n;
+            } while (bits - val + nm1 < 0);
+
+            return val;
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of Lemire (2019).
+     *
+     * @see <a href="https://arxiv.org/abs/1805.10941SplittableRandom"> Lemire
+     * (2019): Fast Random Integer Generation in an Interval</a>
+     */
+    static class Lemire extends SplitMix32 {
+        /** 2^32. */
+        static final long POW_32 = 1L << 32;
+
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        Lemire(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            long m = (next() & 0xffffffffL) * n;
+            long l = m & 0xffffffffL;
+            if (l < n) {
+                // 2^32 % n
+                final long t = POW_32 % n;
+                while (l < t) {
+                    m = (next() & 0xffffffffL) * n;
+                    l = m & 0xffffffffL;
+                }
+            }
+            return (int) (m >>> 32);
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of Lemire (2019) with a case for a power-of-2 range.
+     */
+    static class LemirePow2 extends SplitMix32 {
+        /** 2^32. */
+        static final long POW_32 = 1L << 32;
+
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        LemirePow2(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            final int nm1 = n - 1;
+            if ((n & nm1) == 0) {
+                // Power of 2
+                return next() & nm1;
+            }
+
+            long m = (next() & 0xffffffffL) * n;
+            long l = m & 0xffffffffL;
+            if (l < n) {
+                // 2^32 % n
+                final long t = POW_32 % n;
+                while (l < t) {
+                    m = (next() & 0xffffffffL) * n;
+                    l = m & 0xffffffffL;
+                }
+            }
+            return (int) (m >>> 32);
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of Lemire (2019) modified to 31-bit arithmetic to use
+     * an int modulus operation.
+     */
+    static class Lemire31 extends SplitMix32 {
+        /** 2^32. */
+        static final long POW_32 = 1L << 32;
+
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        Lemire31(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            long m = (nextInt() & 0x7fffffffL) * n;
+            long l = m & 0x7fffffffL;
+            if (l < n) {
+                // 2^31 % n
+                final long t = (Integer.MIN_VALUE - n) % n;
+                while (l < t) {
+                    m = (nextInt() & 0x7fffffffL) * n;
+                    l = m & 0x7fffffffL;
+                }
+            }
+            return (int) (m >>> 31);
+        }
+    }
+
+    /**
+     * Implement the nextInt(int) method of Lemire (2019) modified to 31-bit arithmetic to use
+     * an int modulus operation, with a case for a power-of-2 range.
+     */
+    static class Lemire31Pow2 extends SplitMix32 {
+        /** 2^32. */
+        static final long POW_32 = 1L << 32;
+
+        /**
+         * Create a new instance.
+         *
+         * @param seed the seed
+         */
+        Lemire31Pow2(long seed) {
+            super(seed);
+        }
+
+        @Override
+        public int nextInt(int n) {
+            checkStrictlyPositive(n);
+
+            final int nm1 = n - 1;
+            if ((n & nm1) == 0) {
+                // Power of 2
+                return next() & nm1;
+            }
+
+            long m = (nextInt() & 0x7fffffffL) * n;
+            long l = m & 0x7fffffffL;
+            if (l < n) {
+                // 2^31 % n
+                final long t = (Integer.MIN_VALUE - n) % n;
+                while (l < t) {
+                    m = (nextInt() & 0x7fffffffL) * n;
+                    l = m & 0x7fffffffL;
+                }
+            }
+            return (int) (m >>> 31);
+        }
+    }
+
+    /**
+     * Baseline for a JMH method call returning an {@code int}.
+     *
+     * @return the value
+     */
+    @Benchmark
+    public int baselineInt() {
+        return intValue;
+    }
+
+    /**
+     * Exercise the {@link UniformRandomProvider#nextInt()} method.
+     *
+     * @param range the range
+     * @param source Source of randomness.
+     * @return the int
+     */
+    @Benchmark
+    public int nextIntN(IntRange range, Source source) {
+        return source.getRng().nextInt(range.getN());
+    }
+
+    /**
+     * Exercise the {@link UniformRandomProvider#nextInt()} method in a loop.
+     *
+     * @param range the range
+     * @param source Source of randomness.
+     * @return the int
+     */
+    @Benchmark
+    @OperationsPerInvocation(65536)
+    public int nextIntNloop65536(IntRange range, Source source) {
+        int sum = 0;
+        for (int i = 0; i < 65536; i++) {
+            sum += source.getRng().nextInt(range.getN());
+        }
+        return sum;
+    }
+
+    /**
+     * Exercise the {@link UniformRandomProvider#nextInt(int)} method by shuffling
+     * data.
+     *
+     * @param data the data
+     * @param source Source of randomness.
+     * @return the shuffle data
+     */
+    @Benchmark
+    public int[] shuffle(IntData data, Source source) {
+        final int[] array = data.getData();
+        shuffle(array, source.getRng());
+        return array;
+    }
+
+    /**
+     * Perform a Fischer-Yates shuffle.
+     *
+     * @param array the array
+     * @param rng the random generator
+     */
+    private static void shuffle(int[] array, UniformRandomProvider rng) {
+        for (int i = array.length - 1; i > 0; i--) {
+            // Swap index with any position down to 0
+            final int j = rng.nextInt(i);
+            final int tmp = array[j];
+            array[j] = array[i];
+            array[i] = tmp;
+        }
+    }
+
+    /**
+     * Exercise the {@link UniformRandomProvider#nextInt(int)} method by creating
+     * random indices for shuffling data.
+     *
+     * @param data the data
+     * @param source Source of randomness.
+     * @return the sum
+     */
+    @Benchmark
+    public int pseudoShuffle(IntData data, Source source) {
+        int sum = 0;
+        for (int i = data.getData().length - 1; i > 0; i--) {
+            sum += source.getRng().nextInt(i);
+        }
+        return sum;
+    }
+}