You are viewing a plain-text version of this content. The hyperlink to the canonical version is not preserved in this plain-text rendering.
Posted to commits@mahout.apache.org by sr...@apache.org on 2012/11/27 19:16:22 UTC
svn commit: r1414300 - in /mahout/trunk/core/src/main/java/org/apache/mahout:
cf/taste/impl/common/SamplingLongPrimitiveIterator.java
common/iterator/SamplingIterator.java
Author: srowen
Date: Tue Nov 27 18:16:21 2012
New Revision: 1414300
URL: http://svn.apache.org/viewvc?rev=1414300&view=rev
Log:
Improve random sampling from an iterator. Instead of conducting one random trial per element, draw the number of elements to skip from a geometric distribution, which is a negative binomial (aka Pascal) distribution with r = 1. This is much faster when the sampling rate is near 0.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java?rev=1414300&r1=1414299&r2=1414300&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java Tue Nov 27 18:16:21 2012
@@ -18,9 +18,9 @@
package org.apache.mahout.cf.taste.impl.common;
import java.util.NoSuchElementException;
-import java.util.Random;
-import org.apache.mahout.common.RandomUtils;
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.distribution.PascalDistributionImpl;
/**
* Wraps a {@link LongPrimitiveIterator} and returns only some subset of the elements that it would,
@@ -28,16 +28,15 @@ import org.apache.mahout.common.RandomUt
*/
public final class SamplingLongPrimitiveIterator extends AbstractLongPrimitiveIterator {
- private final Random random;
+ private final PascalDistributionImpl geometricDistribution;
private final LongPrimitiveIterator delegate;
- private final double samplingRate;
private long next;
private boolean hasNext;
public SamplingLongPrimitiveIterator(LongPrimitiveIterator delegate, double samplingRate) {
- random = RandomUtils.getRandom();
+ // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+ geometricDistribution = new PascalDistributionImpl(1, samplingRate);
this.delegate = delegate;
- this.samplingRate = samplingRate;
this.hasNext = true;
doNext();
}
@@ -66,14 +65,13 @@ public final class SamplingLongPrimitive
}
private void doNext() {
- int toSkip = 0;
- while (random.nextDouble() >= samplingRate) {
- toSkip++;
- }
- // Really, would be nicer to select value from geometric distribution, for small values of samplingRate
- if (toSkip > 0) {
- delegate.skip(toSkip);
+ int toSkip;
+ try {
+ toSkip = geometricDistribution.sample();
+ } catch (MathException e) {
+ throw new IllegalStateException(e);
}
+ delegate.skip(toSkip);
if (delegate.hasNext()) {
next = delegate.next();
} else {
@@ -91,7 +89,15 @@ public final class SamplingLongPrimitive
@Override
public void skip(int n) {
- delegate.skip((int) (n / samplingRate)); // Kind of an approximation, but this is expected skip
+ int toSkip = 0;
+ try {
+ for (int i = 0; i < n; i++) {
+ toSkip += geometricDistribution.sample();
+ }
+ } catch (MathException e) {
+ throw new IllegalStateException(e);
+ }
+ delegate.skip(toSkip);
if (delegate.hasNext()) {
next = delegate.next();
} else {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java?rev=1414300&r1=1414299&r2=1414300&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java Tue Nov 27 18:16:21 2012
@@ -18,11 +18,11 @@
package org.apache.mahout.common.iterator;
import java.util.Iterator;
-import java.util.Random;
import com.google.common.collect.AbstractIterator;
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.distribution.PascalDistributionImpl;
import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
-import org.apache.mahout.common.RandomUtils;
/**
* Wraps an {@link Iterator} and returns only some subset of the elements that it would, as determined by a
@@ -30,37 +30,35 @@ import org.apache.mahout.common.RandomUt
*/
public final class SamplingIterator<T> extends AbstractIterator<T> {
- private final Random random;
+ private final PascalDistributionImpl geometricDistribution;
private final Iterator<? extends T> delegate;
- private final double samplingRate;
-
+
public SamplingIterator(Iterator<? extends T> delegate, double samplingRate) {
- random = RandomUtils.getRandom();
+ // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+ geometricDistribution = new PascalDistributionImpl(1, samplingRate);
this.delegate = delegate;
- this.samplingRate = samplingRate;
}
@Override
protected T computeNext() {
+ int toSkip;
+ try {
+ toSkip = geometricDistribution.sample();
+ } catch (MathException e) {
+ throw new IllegalStateException(e);
+ }
if (delegate instanceof SkippingIterator<?>) {
SkippingIterator<? extends T> skippingDelegate = (SkippingIterator<? extends T>) delegate;
- int toSkip = 0;
- while (random.nextDouble() >= samplingRate) {
- toSkip++;
- }
- // Really, would be nicer to select value from geometric distribution, for small values of samplingRate
- if (toSkip > 0) {
- skippingDelegate.skip(toSkip);
- }
+ skippingDelegate.skip(toSkip);
if (skippingDelegate.hasNext()) {
return skippingDelegate.next();
}
} else {
- while (delegate.hasNext()) {
- T delegateNext = delegate.next();
- if (random.nextDouble() < samplingRate) {
- return delegateNext;
- }
+ for (int i = 0; i < toSkip && delegate.hasNext(); i++) {
+ delegate.next();
+ }
+ if (delegate.hasNext()) {
+ return delegate.next();
}
}
return endOfData();