You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2012/11/27 19:16:22 UTC

svn commit: r1414300 - in /mahout/trunk/core/src/main/java/org/apache/mahout: cf/taste/impl/common/SamplingLongPrimitiveIterator.java common/iterator/SamplingIterator.java

Author: srowen
Date: Tue Nov 27 18:16:21 2012
New Revision: 1414300

URL: http://svn.apache.org/viewvc?rev=1414300&view=rev
Log:
Improve random sampling from an iterator by drawing the number of elements to skip from a negative binomial distribution (the geometric special case, r=1) instead of conducting one random trial per element. This is much faster when the sampling rate is near 0.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
    mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java?rev=1414300&r1=1414299&r2=1414300&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java Tue Nov 27 18:16:21 2012
@@ -18,9 +18,9 @@
 package org.apache.mahout.cf.taste.impl.common;
 
 import java.util.NoSuchElementException;
-import java.util.Random;
 
-import org.apache.mahout.common.RandomUtils;
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.distribution.PascalDistributionImpl;
 
 /**
  * Wraps a {@link LongPrimitiveIterator} and returns only some subset of the elements that it would,
@@ -28,16 +28,15 @@ import org.apache.mahout.common.RandomUt
  */
 public final class SamplingLongPrimitiveIterator extends AbstractLongPrimitiveIterator {
   
-  private final Random random;
+  private final PascalDistributionImpl geometricDistribution;
   private final LongPrimitiveIterator delegate;
-  private final double samplingRate;
   private long next;
   private boolean hasNext;
   
   public SamplingLongPrimitiveIterator(LongPrimitiveIterator delegate, double samplingRate) {
-    random = RandomUtils.getRandom();
+    // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+    geometricDistribution = new PascalDistributionImpl(1, samplingRate);
     this.delegate = delegate;
-    this.samplingRate = samplingRate;
     this.hasNext = true;
     doNext();
   }
@@ -66,14 +65,13 @@ public final class SamplingLongPrimitive
   }
   
   private void doNext() {
-    int toSkip = 0;
-    while (random.nextDouble() >= samplingRate) {
-      toSkip++;
-    }
-    // Really, would be nicer to select value from geometric distribution, for small values of samplingRate
-    if (toSkip > 0) {
-      delegate.skip(toSkip);
+    int toSkip;
+    try {
+      toSkip = geometricDistribution.sample();
+    } catch (MathException e) {
+      throw new IllegalStateException(e);
     }
+    delegate.skip(toSkip);
     if (delegate.hasNext()) {
       next = delegate.next();
     } else {
@@ -91,7 +89,15 @@ public final class SamplingLongPrimitive
   
   @Override
   public void skip(int n) {
-    delegate.skip((int) (n / samplingRate)); // Kind of an approximation, but this is expected skip
+    int toSkip = 0;
+    try {
+      for (int i = 0; i < n; i++) {
+        toSkip += geometricDistribution.sample();
+      }
+    } catch (MathException e) {
+      throw new IllegalStateException(e);
+    }
+    delegate.skip(toSkip);
     if (delegate.hasNext()) {
       next = delegate.next();
     } else {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java?rev=1414300&r1=1414299&r2=1414300&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java Tue Nov 27 18:16:21 2012
@@ -18,11 +18,11 @@
 package org.apache.mahout.common.iterator;
 
 import java.util.Iterator;
-import java.util.Random;
 
 import com.google.common.collect.AbstractIterator;
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.distribution.PascalDistributionImpl;
 import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
-import org.apache.mahout.common.RandomUtils;
 
 /**
  * Wraps an {@link Iterator} and returns only some subset of the elements that it would, as determined by a
@@ -30,37 +30,35 @@ import org.apache.mahout.common.RandomUt
  */
 public final class SamplingIterator<T> extends AbstractIterator<T> {
   
-  private final Random random;
+  private final PascalDistributionImpl geometricDistribution;
   private final Iterator<? extends T> delegate;
-  private final double samplingRate;
-  
+
   public SamplingIterator(Iterator<? extends T> delegate, double samplingRate) {
-    random = RandomUtils.getRandom();
+    // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+    geometricDistribution = new PascalDistributionImpl(1, samplingRate);
     this.delegate = delegate;
-    this.samplingRate = samplingRate;
   }
 
   @Override
   protected T computeNext() {
+    int toSkip;
+    try {
+      toSkip = geometricDistribution.sample();
+    } catch (MathException e) {
+      throw new IllegalStateException(e);
+    }
     if (delegate instanceof SkippingIterator<?>) {
       SkippingIterator<? extends T> skippingDelegate = (SkippingIterator<? extends T>) delegate;
-      int toSkip = 0;
-      while (random.nextDouble() >= samplingRate) {
-        toSkip++;
-      }
-      // Really, would be nicer to select value from geometric distribution, for small values of samplingRate
-      if (toSkip > 0) {
-        skippingDelegate.skip(toSkip);
-      }
+      skippingDelegate.skip(toSkip);
       if (skippingDelegate.hasNext()) {
         return skippingDelegate.next();
       }
     } else {
-      while (delegate.hasNext()) {
-        T delegateNext = delegate.next();
-        if (random.nextDouble() < samplingRate) {
-          return delegateNext;
-        }
+      for (int i = 0; i < toSkip && delegate.hasNext(); i++) {
+        delegate.next();
+      }
+      if (delegate.hasNext()) {
+        return delegate.next();
       }
     }
     return endOfData();