You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by tn...@apache.org on 2014/09/30 21:16:31 UTC

git commit: [MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny.

Repository: commons-math
Updated Branches:
  refs/heads/master 97d32b14e -> 97accb47d


[MATH-1152] Improved performance of EnumeratedDistribution#sample(). Thanks to Andras Sereny.


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/97accb47
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/97accb47
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/97accb47

Branch: refs/heads/master
Commit: 97accb47de63ee5063eda23641c6017e29ab81d7
Parents: 97d32b1
Author: Thomas Neidhart <th...@gmail.com>
Authored: Tue Sep 30 21:16:07 2014 +0200
Committer: Thomas Neidhart <th...@gmail.com>
Committed: Tue Sep 30 21:16:07 2014 +0200

----------------------------------------------------------------------
 src/changes/changes.xml                         |  4 +++
 .../distribution/EnumeratedDistribution.java    | 31 +++++++++++++++-----
 2 files changed, 28 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/97accb47/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index c64fb58..fdb2bd4 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -73,6 +73,10 @@ Users are encouraged to upgrade to this version as this release not
   2. A few methods in the FastMath class are in fact slower that their
   counterpart in either Math or StrictMath (cf. MATH-740 and MATH-901).
 ">
+      <action dev="tn" type="fix" issue="MATH-1152" due-to="Andras Sereny">
+        Improved performance of "EnumeratedDistribution#sample()" by caching
+        the cumulative probabilities and using a binary rather than a linear search.
+      </action>
       <action dev="tn" type="fix" issue="MATH-1148" due-to="Guillaume Marceau">
         "MonotoneChain" did not take the tolerance factor into account when
         sorting the input points. In case of collinear points this could result

http://git-wip-us.apache.org/repos/asf/commons-math/blob/97accb47/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java
index 5117c2a..e95098c 100644
--- a/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java
+++ b/src/main/java/org/apache/commons/math3/distribution/EnumeratedDistribution.java
@@ -19,6 +19,7 @@ package org.apache.commons.math3.distribution;
 import java.io.Serializable;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 import org.apache.commons.math3.exception.MathArithmeticException;
@@ -64,6 +65,7 @@ public class EnumeratedDistribution<T> implements Serializable {
      * List of random variable values.
      */
     private final List<T> singletons;
+
     /**
      * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
      * probability[i] is the probability that a random variable following this distribution takes
@@ -72,6 +74,11 @@ public class EnumeratedDistribution<T> implements Serializable {
     private final double[] probabilities;
 
     /**
+     * Cumulative probabilities, cached to speed up sampling.
+     */
+    private final double[] cumulativeProbabilities;
+
+    /**
      * Create an enumerated distribution using the given probability mass function
      * enumeration.
      *
@@ -123,6 +130,13 @@ public class EnumeratedDistribution<T> implements Serializable {
         }
 
         probabilities = MathArrays.normalizeArray(probs, 1.0);
+
+        cumulativeProbabilities = new double[probabilities.length];
+        double sum = 0;
+        for (int i = 0; i < probabilities.length; i++) {
+            sum += probabilities[i];
+            cumulativeProbabilities[i] = sum;
+        }
     }
 
     /**
@@ -186,18 +200,21 @@ public class EnumeratedDistribution<T> implements Serializable {
      */
     public T sample() {
         final double randomValue = random.nextDouble();
-        double sum = 0;
 
-        for (int i = 0; i < probabilities.length; i++) {
-            sum += probabilities[i];
-            if (randomValue < sum) {
-                return singletons.get(i);
+        int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
+        if (index < 0) {
+            index = -index-1;
+        }
+
+        if (index >= 0 && index < probabilities.length) {
+            if (randomValue < cumulativeProbabilities[index]) {
+                return singletons.get(index);
             }
         }
 
         /* This should never happen, but it ensures we will return a correct
-         * object in case the loop above has some floating point inequality
-         * problem on the final iteration. */
+         * object in case there is some floating point inequality problem
+         * wrt the cumulative probabilities. */
         return singletons.get(singletons.size() - 1);
     }