You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by er...@apache.org on 2016/05/31 09:15:49 UTC

[2/2] [math] MATH-1370

MATH-1370

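Cache the aggregated probability of each distinct value in a map, so that
"probability(x)" is a single map lookup instead of a scan over all singletons.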

Thanks to Ryan Gaffney.


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/0cc22651
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/0cc22651
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/0cc22651

Branch: refs/heads/feature-MATH-1370
Commit: 0cc22651bbd5db9fefac5f3114bb881aadca0d0c
Parents: 69ed91c
Author: Gilles <gi...@harfang.homelinux.org>
Authored: Tue May 31 11:01:08 2016 +0200
Committer: Gilles <gi...@harfang.homelinux.org>
Committed: Tue May 31 11:01:08 2016 +0200

----------------------------------------------------------------------
 .../distribution/EnumeratedDistribution.java    | 43 ++++++++------
 .../EnumeratedDistributionTest.java             | 61 ++++++++++++++++++++
 2 files changed, 87 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/0cc22651/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
index 88ce037..3fa6279 100644
--- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
+++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
@@ -20,7 +20,9 @@ import java.io.Serializable;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.commons.math4.exception.MathArithmeticException;
 import org.apache.commons.math4.exception.NotANumberException;
@@ -59,11 +61,16 @@ public class EnumeratedDistribution<T> implements Serializable {
     private final List<T> singletons;
     /**
      * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
-     * probability[i] is the probability that a random variable following this distribution takes
+     * probabilities[i] is the probability that a random variable following this distribution takes
      * the value singletons[i].
      */
     private final double[] probabilities;
     /**
+     * Probabilities of aggregated distinct random variables, cached to speed up probability lookup.
+     */
+    private final Map<T, Double> massPoints;
+
+    /**
      * Cumulative probabilities, cached to speed up sampling.
      */
     private final double[] cumulativeProbabilities;
@@ -92,7 +99,7 @@ public class EnumeratedDistribution<T> implements Serializable {
             singletons.add(sample.getKey());
             final double p = sample.getValue();
             if (p < 0) {
-                throw new NotPositiveException(sample.getValue());
+                throw new NotPositiveException(p);
             }
             if (Double.isInfinite(p)) {
                 throw new NotFiniteNumberException(p);
@@ -105,11 +112,23 @@ public class EnumeratedDistribution<T> implements Serializable {
 
         probabilities = MathArrays.normalizeArray(probs, 1.0);
 
+        massPoints = new HashMap<T, Double>();
         cumulativeProbabilities = new double[probabilities.length];
         double sum = 0;
         for (int i = 0; i < probabilities.length; i++) {
-            sum += probabilities[i];
+            double probability = probabilities[i];
+
+            sum += probability;
             cumulativeProbabilities[i] = sum;
+
+            T randomVariable = singletons.get(i);
+            final double existingProbability;
+            if (massPoints.containsKey(randomVariable)) {
+                existingProbability = massPoints.get(randomVariable);
+            } else {
+                existingProbability = 0.0;
+            }
+            massPoints.put(randomVariable, existingProbability + probability);
         }
     }
 
@@ -120,22 +139,14 @@ public class EnumeratedDistribution<T> implements Serializable {
      * distribution.</p>
      *
      * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
-     * or both are null, then {@code probability(x1) = probability(x2)}.</p>
+     * or both are null, then {@code probability(x1) == probability(x2)}.</p>
      *
      * @param x the point at which the PMF is evaluated
      * @return the value of the probability mass function at {@code x}
      */
     double probability(final T x) {
-        double probability = 0;
-
-        for (int i = 0; i < probabilities.length; i++) {
-            if ((x == null && singletons.get(i) == null) ||
-                (x != null && x.equals(singletons.get(i)))) {
-                probability += probabilities[i];
-            }
-        }
-
-        return probability;
+	final Double p = massPoints.get(x);
+	return p == null ? 0 : p.doubleValue();
     }
 
     /**
@@ -195,9 +206,7 @@ public class EnumeratedDistribution<T> implements Serializable {
                 index = -index - 1;
             }
 
-            if (index >= 0 &&
-                index < probabilities.length &&
-                randomValue < cumulativeProbabilities[index]) {
+            if (randomValue < cumulativeProbabilities[index]) {
                 return singletons.get(index);
             }
 

http://git-wip-us.apache.org/repos/asf/commons-math/blob/0cc22651/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java
new file mode 100644
index 0000000..c1756b0
--- /dev/null
+++ b/src/test/java/org/apache/commons/math4/distribution/EnumeratedDistributionTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math4.distribution;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.commons.math4.util.Pair;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+
+/**
+ * Test class for {@link EnumeratedDistribution}.
+ */
+public class EnumeratedDistributionTest {
+    @Test
+    public void testProbability() {
+        final String[] values = {"car", "bike", null};
+        final List<Pair<String, Double>> pmf = Arrays.asList(
+            new Pair<String, Double>(values[0], 0.1),
+            new Pair<String, Double>(values[1], 0.3),
+            new Pair<String, Double>(values[1], 0.2),
+            new Pair<String, Double>(values[2], 0.2),
+            new Pair<String, Double>(values[2], 0.2)
+        );
+        final EnumeratedDistribution<String> distribution = new EnumeratedDistribution<String>(pmf);
+        assertEquals(0.1, distribution.probability(values[0]), 0);
+        assertEquals(0.5, distribution.probability(values[1]), 0);
+        assertEquals(0.4, distribution.probability(values[2]), 0);
+    }
+
+    @Test
+    public void testGetPmf() {
+	final String s = "bike";
+        final List<Pair<String, Double>> pmf = Arrays.asList(
+            new Pair<String, Double>(s, 0.1),
+            new Pair<String, Double>(s, 0.3),
+            new Pair<String, Double>(null, 0.2),
+            new Pair<String, Double>(s, 0.2),
+            new Pair<String, Double>(null, 0.2)
+        );
+        final EnumeratedDistribution<String> distribution = new EnumeratedDistribution<String>(pmf);
+        assertEquals(pmf, distribution.getPmf());
+    }
+}