You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@crunch.apache.org by jw...@apache.org on 2013/03/13 17:46:32 UTC
git commit: CRUNCH-178: Add reservoir sampling functions to
lib.Sample and make the reservoir and regular sampling APIs consistent.
Updated Branches:
refs/heads/master 1138d3855 -> 96f5e9f8c
CRUNCH-178: Add reservoir sampling functions to lib.Sample and make the reservoir and
regular sampling APIs consistent.
Project: http://git-wip-us.apache.org/repos/asf/crunch/repo
Commit: http://git-wip-us.apache.org/repos/asf/crunch/commit/96f5e9f8
Tree: http://git-wip-us.apache.org/repos/asf/crunch/tree/96f5e9f8
Diff: http://git-wip-us.apache.org/repos/asf/crunch/diff/96f5e9f8
Branch: refs/heads/master
Commit: 96f5e9f8cbaf387c93204db2e9d15430154124cb
Parents: 1138d38
Author: Josh Wills <jw...@apache.org>
Authored: Wed Mar 6 15:09:28 2013 -0800
Committer: Josh Wills <jw...@apache.org>
Committed: Wed Mar 13 01:04:09 2013 -0700
----------------------------------------------------------------------
.../main/java/org/apache/crunch/lib/Sample.java | 197 ++++++++++++---
.../java/org/apache/crunch/lib/SampleUtils.java | 161 ++++++++++++
.../java/org/apache/crunch/lib/SampleTest.java | 37 +++-
3 files changed, 356 insertions(+), 39 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/Sample.java
----------------------------------------------------------------------
diff --git a/crunch/src/main/java/org/apache/crunch/lib/Sample.java b/crunch/src/main/java/org/apache/crunch/lib/Sample.java
index 5be2292..be75ae2 100644
--- a/crunch/src/main/java/org/apache/crunch/lib/Sample.java
+++ b/crunch/src/main/java/org/apache/crunch/lib/Sample.java
@@ -17,51 +17,37 @@
*/
package org.apache.crunch.lib;
-import java.util.Random;
-import org.apache.crunch.FilterFn;
+import org.apache.crunch.MapFn;
import org.apache.crunch.PCollection;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
+import org.apache.crunch.lib.SampleUtils.ReservoirSampleFn;
+import org.apache.crunch.lib.SampleUtils.SampleFn;
+import org.apache.crunch.lib.SampleUtils.WRSCombineFn;
+import org.apache.crunch.types.PTableType;
+import org.apache.crunch.types.PType;
+import org.apache.crunch.types.PTypeFamily;
-import com.google.common.base.Preconditions;
-
+/**
+ * Methods for performing random sampling in a distributed fashion, either by accepting each
+ * record in a {@code PCollection} with an independent probability in order to sample some
+ * fraction of the overall data set, or by using reservoir sampling in order to pull a uniform
+ * or weighted sample of fixed size from a {@code PCollection} of an unknown size. For more details
+ * on the reservoir sampling algorithms used by this library, see the A-ES algorithm described in
+ * <a href="http://arxiv.org/pdf/1012.0256.pdf">Efraimidis (2012)</a>.
+ */
public class Sample {
- private static class SamplerFn<S> extends FilterFn<S> {
-
- private final long seed;
- private final double acceptanceProbability;
- private transient Random r;
-
- public SamplerFn(long seed, double acceptanceProbability) {
- Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0);
- this.seed = seed;
- this.acceptanceProbability = acceptanceProbability;
- }
-
- @Override
- public void initialize() {
- if (r == null) {
- r = new Random(seed);
- }
- }
-
- @Override
- public boolean accept(S input) {
- return r.nextDouble() < acceptanceProbability;
- }
- }
-
/**
* Output records from the given {@code PCollection} with the given probability.
*
* @param input The {@code PCollection} to sample from
- * @param probability The probability (0.0 < p < 1.0)
+ * @param probability The probability (0.0 &lt; p &lt; 1.0)
* @return The output {@code PCollection} created from sampling
*/
public static <S> PCollection<S> sample(PCollection<S> input, double probability) {
- return sample(input, System.currentTimeMillis(), probability);
+ return sample(input, null, probability);
}
/**
@@ -69,26 +55,163 @@ public class Sample {
* testing.
*
* @param input The {@code PCollection} to sample from
- * @param seed The seed
- * @param probability The probability (0.0 < p < 1.0)
+ * @param seed The seed for the random number generator, or {@code null} to seed it from the current system time
+ * @param probability The probability (0.0 < p < 1.0)
* @return The output {@code PCollection} created from sampling
*/
- public static <S> PCollection<S> sample(PCollection<S> input, long seed, double probability) {
+ public static <S> PCollection<S> sample(PCollection<S> input, Long seed, double probability) {
String stageName = String.format("sample(%.2f)", probability);
- return input.parallelDo(stageName, new SamplerFn<S>(seed, probability), input.getPType());
+ return input.parallelDo(stageName, new SampleFn<S>(probability, seed), input.getPType());
}
/**
* A {@code PTable<K, V>} analogue of the {@code sample} function.
+ * Each key-value pair is kept or dropped independently with the given probability.
+ *
+ * @param input The {@code PTable} to sample from
+ * @param probability The probability (0.0 < p < 1.0)
+ * @return The output {@code PTable} created from sampling
*/
public static <K, V> PTable<K, V> sample(PTable<K, V> input, double probability) {
return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, probability));
}
/**
- * A {@code PTable<K, V>} analogue of the {@code sample} function.
+ * A {@code PTable<K, V>} analogue of the {@code sample} function, with the seed argument
+ * exposed for testing purposes.
+ *
+ * @param input The {@code PTable} to sample from
+ * @param seed The seed for the random number generator, or {@code null} to seed it from the current system time
+ * @param probability The probability (0.0 < p < 1.0)
+ * @return The output {@code PTable} created from sampling
*/
- public static <K, V> PTable<K, V> sample(PTable<K, V> input, long seed, double probability) {
+ public static <K, V> PTable<K, V> sample(PTable<K, V> input, Long seed, double probability) {
return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, seed, probability));
}
+
+  /**
+   * Select a fixed number of elements from the given {@code PCollection} with each element
+   * equally likely to be included in the sample.
+   *
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @return A {@code PCollection} made up of the sampled elements
+   */
+  public static <T> PCollection<T> reservoirSample(
+      PCollection<T> input,
+      int sampleSize) {
+    return reservoirSample(input, sampleSize, null);
+  }
+
+  /**
+   * A version of the reservoir sampling algorithm that uses a given seed, primarily for
+   * testing purposes.
+   *
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @param seed The test seed, or {@code null} to use a default-seeded random number generator
+   * @return A {@code PCollection} made up of the sampled elements
+   */
+  public static <T> PCollection<T> reservoirSample(
+      PCollection<T> input,
+      int sampleSize,
+      Long seed) {
+    PTypeFamily ptf = input.getTypeFamily();
+    PType<Pair<T, Integer>> ptype = ptf.pairs(input.getPType(), ptf.ints());
+    // Give every element the same weight (1) so the weighted sampler draws uniformly.
+    return weightedReservoirSample(
+        input.parallelDo(new MapFn<T, Pair<T, Integer>>() {
+          @Override
+          public Pair<T, Integer> map(T t) {
+            return Pair.of(t, 1);
+          }
+        }, ptype),
+        sampleSize,
+        seed);
+  }
+
+  /**
+   * Misspelled alias of {@link #reservoirSample(PCollection, int, Long)}, kept so that existing
+   * callers continue to compile.
+   *
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @param seed The test seed, or {@code null} to use a default-seeded random number generator
+   * @return A {@code PCollection} made up of the sampled elements
+   * @deprecated Use {@link #reservoirSample(PCollection, int, Long)} instead.
+   */
+  @Deprecated
+  public static <T> PCollection<T> reservorSample(
+      PCollection<T> input,
+      int sampleSize,
+      Long seed) {
+    return reservoirSample(input, sampleSize, seed);
+  }
+
+ /**
+ * Selects a weighted sample of the elements of the given {@code PCollection}, where the second term in
+ * the input {@code Pair} is a numerical weight.
+ * Elements whose weight is not strictly positive are never selected.
+ *
+ * @param input the weighted observations
+ * @param sampleSize The number of elements to select
+ * @return A random sample of the given size that respects the weighting values (smaller than
+ * {@code sampleSize} if fewer elements have a positive weight)
+ */
+ public static <T, N extends Number> PCollection<T> weightedReservoirSample(
+ PCollection<Pair<T, N>> input,
+ int sampleSize) {
+ return weightedReservoirSample(input, sampleSize, null);
+ }
+
+ /**
+ * The weighted reservoir sampling function with the seed term exposed for testing purposes.
+ * Elements whose weight is not strictly positive are never selected.
+ *
+ * @param input the weighted observations
+ * @param sampleSize The number of elements to select
+ * @param seed The test seed, or {@code null} to use a default-seeded random number generator
+ * @return A random sample of the given size that respects the weighting values (smaller than
+ * {@code sampleSize} if fewer elements have a positive weight)
+ */
+ public static <T, N extends Number> PCollection<T> weightedReservoirSample(
+ PCollection<Pair<T, N>> input,
+ int sampleSize,
+ Long seed) {
+ PTypeFamily ptf = input.getTypeFamily();
+ // Route every observation to group 0 so a single reservoir of sampleSize is used.
+ PTable<Integer, Pair<T, N>> groupedIn = input.parallelDo(
+ new MapFn<Pair<T, N>, Pair<Integer, Pair<T, N>>>() {
+ @Override
+ public Pair<Integer, Pair<T, N>> map(Pair<T, N> p) {
+ return Pair.of(0, p);
+ }
+ }, ptf.tableOf(ptf.ints(), input.getPType()));
+ int[] ss = new int[] { sampleSize };
+ // Drop the group ID again; the output PType is taken from sub-type 0 of the input pair
+ // PType, which is assumed to be T's PType (unchecked cast).
+ return groupedWeightedReservoirSample(groupedIn, ss, seed)
+ .parallelDo(new MapFn<Pair<Integer, T>, T>() {
+ @Override
+ public T map(Pair<Integer, T> p) {
+ return p.second();
+ }
+ }, (PType<T>) input.getPType().getSubTypes().get(0));
+ }
+
+ /**
+ * The most general-purpose of the weighted reservoir sampling methods: chooses
+ * a random sample of elements for each of N input groups.
+ * Group IDs are used as indices into {@code sampleSizes}, so every key must lie in the range
+ * [0, N). Observations whose weight is not strictly positive are never selected.
+ *
+ * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group
+ * @param sampleSizes An array of length N, where each entry is the number of elements to include in that group
+ * @return A {@code PCollection} of the sampled elements for each of the groups
+ */
+
+ public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(
+ PTable<Integer, Pair<T, N>> input,
+ int[] sampleSizes) {
+ return groupedWeightedReservoirSample(input, sampleSizes, null);
+ }
+
+ /**
+ * Same as the other groupedWeightedReservoirSample method, but includes a seed for testing
+ * purposes.
+ *
+ * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group
+ * @param sampleSizes An array of length N, where each entry is the number of elements to include in that group
+ * @param seed The test seed, or {@code null} to use a default-seeded random number generator
+ * @return A {@code PCollection} of the sampled elements for each of the groups
+ */
+ public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(
+ PTable<Integer, Pair<T, N>> input,
+ int[] sampleSizes,
+ Long seed) {
+ PTypeFamily ptf = input.getTypeFamily();
+ // Element PType: sub-type 0 of the (T, N) value pair, assumed to be T's PType (unchecked cast).
+ PType<T> ttype = (PType<T>) input.getPTableType().getValueType().getSubTypes().get(0);
+ PTableType<Integer, Pair<Double, T>> ptt = ptf.tableOf(ptf.ints(),
+ ptf.pairs(ptf.doubles(), ttype));
+
+ // 1) ReservoirSampleFn scores each observation and emits each task's top candidates per group
+ // as (groupId, (score, value)); 2) WRSCombineFn keeps the top sampleSizes[groupId] per group;
+ // 3) the final MapFn drops the score.
+ return input.parallelDo(new ReservoirSampleFn<T, N>(sampleSizes, seed), ptt)
+ .groupByKey(1)
+ .combineValues(new WRSCombineFn<T>(sampleSizes))
+ .parallelDo(new MapFn<Pair<Integer, Pair<Double, T>>, Pair<Integer, T>>() {
+ @Override
+ public Pair<Integer, T> map(Pair<Integer, Pair<Double, T>> p) {
+ return Pair.of(p.first(), p.second().second());
+ }
+ }, ptf.pairs(ptf.ints(), ttype));
+ }
+
}
http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
----------------------------------------------------------------------
diff --git a/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
new file mode 100644
index 0000000..cbc30e4
--- /dev/null
+++ b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.crunch.lib;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.SortedMap;
+
+import org.apache.crunch.CombineFn;
+import org.apache.crunch.DoFn;
+import org.apache.crunch.Emitter;
+import org.apache.crunch.FilterFn;
+import org.apache.crunch.Pair;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+/**
+ * Package-private function implementations that back the sampling methods in {@code Sample}.
+ */
+class SampleUtils {
+
+  /**
+   * Keeps each input record independently with probability {@code acceptanceProbability}.
+   */
+  static class SampleFn<S> extends FilterFn<S> {
+
+    private final Long seed;
+    private final double acceptanceProbability;
+    private transient Random r;
+
+    /**
+     * @param acceptanceProbability the probability of keeping each record; must be in (0.0, 1.0)
+     * @param seed the seed for the random number generator, or {@code null} to use the current
+     *     system time (read when this function is constructed)
+     */
+    public SampleFn(double acceptanceProbability, Long seed) {
+      Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0,
+          "acceptanceProbability must be in (0.0, 1.0), got %s", acceptanceProbability);
+      this.seed = seed == null ? System.currentTimeMillis() : seed;
+      this.acceptanceProbability = acceptanceProbability;
+    }
+
+    @Override
+    public void initialize() {
+      // The generator is transient (not serialized with the function), so create it lazily here.
+      if (r == null) {
+        r = new Random(seed);
+      }
+    }
+
+    @Override
+    public boolean accept(S input) {
+      return r.nextDouble() < acceptanceProbability;
+    }
+  }
+
+  /**
+   * Scores each weighted observation with the A-ES key and keeps, per group, the
+   * {@code sampleSizes[group]} highest-scoring observations seen by this function instance.
+   * The surviving candidates are emitted as {@code (groupId, (score, value))} in {@code cleanup}.
+   *
+   * <p>Group IDs (the input keys) index into {@code sampleSizes}, so they must lie in
+   * {@code [0, sampleSizes.length)}. Observations whose weight is not strictly positive are dropped.
+   */
+  static class ReservoirSampleFn<T, N extends Number>
+      extends DoFn<Pair<Integer, Pair<T, N>>, Pair<Integer, Pair<Double, T>>> {
+
+    private final int[] sampleSizes;
+    private final Long seed;
+    private transient List<SortedMap<Double, T>> reservoirs;
+    private transient Random random;
+
+    /**
+     * @param sampleSizes the number of elements to keep for each group ID
+     * @param seed the seed for the random number generator, or {@code null} for a default-seeded one
+     */
+    public ReservoirSampleFn(int[] sampleSizes, Long seed) {
+      this.sampleSizes = Preconditions.checkNotNull(sampleSizes, "sampleSizes");
+      this.seed = seed;
+    }
+
+    @Override
+    public void initialize() {
+      this.reservoirs = newReservoirs(sampleSizes.length);
+      if (random == null) {
+        this.random = seed == null ? new Random() : new Random(seed);
+      }
+    }
+
+    @Override
+    public void process(Pair<Integer, Pair<T, N>> input,
+        Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      int id = input.first();
+      Pair<T, N> p = input.second();
+      double weight = p.second().doubleValue();
+      if (weight > 0.0) {
+        // A-ES key u^(1/w), compared in log space as log(u) / w; log is monotonic, so the
+        // ordering of keys is unchanged.
+        double score = Math.log(random.nextDouble()) / weight;
+        offer(reservoirs.get(id), sampleSizes[id], score, p.first());
+      }
+    }
+
+    @Override
+    public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      emitAll(reservoirs, emitter);
+    }
+  }
+
+  /**
+   * Merges scored candidates per group, keeping the {@code sampleSizes[group]} highest-scoring
+   * ones. The survivors are emitted as {@code (groupId, (score, value))} in {@code cleanup}
+   * rather than per key.
+   */
+  static class WRSCombineFn<T> extends CombineFn<Integer, Pair<Double, T>> {
+
+    private final int[] sampleSizes;
+    private transient List<SortedMap<Double, T>> reservoirs;
+
+    /**
+     * @param sampleSizes the number of elements to keep for each group ID
+     */
+    public WRSCombineFn(int[] sampleSizes) {
+      this.sampleSizes = Preconditions.checkNotNull(sampleSizes, "sampleSizes");
+    }
+
+    @Override
+    public void initialize() {
+      this.reservoirs = newReservoirs(sampleSizes.length);
+    }
+
+    @Override
+    public void process(Pair<Integer, Iterable<Pair<Double, T>>> input,
+        Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      int id = input.first();
+      SortedMap<Double, T> reservoir = reservoirs.get(id);
+      for (Pair<Double, T> p : input.second()) {
+        offer(reservoir, sampleSizes[id], p.first(), p.second());
+      }
+    }
+
+    @Override
+    public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      emitAll(reservoirs, emitter);
+    }
+  }
+
+  /** Creates {@code count} empty reservoirs, each ordered by ascending score. */
+  private static <T> List<SortedMap<Double, T>> newReservoirs(int count) {
+    List<SortedMap<Double, T>> reservoirs = Lists.newArrayListWithCapacity(count);
+    for (int i = 0; i < count; i++) {
+      reservoirs.add(Maps.<Double, T>newTreeMap());
+    }
+    return reservoirs;
+  }
+
+  /**
+   * Offers a scored candidate to a reservoir that holds at most {@code capacity} entries, evicting
+   * the lowest-scoring entry when the reservoir is full. A capacity of zero or less keeps nothing;
+   * without this guard, {@code firstKey()} would throw on the still-empty map.
+   */
+  private static <T> void offer(SortedMap<Double, T> reservoir, int capacity, double score,
+      T value) {
+    if (capacity <= 0) {
+      return;
+    }
+    if (reservoir.size() < capacity) {
+      reservoir.put(score, value);
+    } else if (score > reservoir.firstKey()) {
+      reservoir.remove(reservoir.firstKey());
+      reservoir.put(score, value);
+    }
+  }
+
+  /** Emits every reservoir entry as {@code (groupId, (score, value))}, group by group. */
+  private static <T> void emitAll(List<SortedMap<Double, T>> reservoirs,
+      Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+    for (int id = 0; id < reservoirs.size(); id++) {
+      for (Map.Entry<Double, T> e : reservoirs.get(id).entrySet()) {
+        emitter.emit(Pair.of(id, Pair.of(e.getKey(), e.getValue())));
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
----------------------------------------------------------------------
diff --git a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
index 69fd074..bd6fd81 100644
--- a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
+++ b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
@@ -20,18 +20,51 @@ package org.apache.crunch.lib;
import static org.junit.Assert.assertEquals;
import java.util.List;
+import java.util.Map;
import org.apache.crunch.PCollection;
+import org.apache.crunch.Pair;
import org.apache.crunch.impl.mem.MemPipeline;
+import org.apache.crunch.types.writable.Writables;
import org.junit.Test;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Maps;
public class SampleTest {
+ // Weighted observations; the weights are in the ratio foo:bar:baz:biz = 2:4:1:1.
+ private PCollection<Pair<String, Double>> values = MemPipeline.typedCollectionOf(
+ Writables.pairs(Writables.strings(), Writables.doubles()),
+ ImmutableList.of(
+ Pair.of("foo", 200.0),
+ Pair.of("bar", 400.0),
+ Pair.of("baz", 100.0),
+ Pair.of("biz", 100.0)));
+
@Test
- public void testSampler() {
+ public void testWRS() throws Exception {
+ // Count how often each value is chosen across 100 single-element weighted samples,
+ // each drawn with a distinct fixed seed (1729 through 1828).
+ Map<String, Integer> histogram = Maps.newHashMap();
+
+ for (int i = 0; i < 100; i++) {
+ PCollection<String> sample = Sample.weightedReservoirSample(values, 1, 1729L + i);
+ for (String s : sample.materialize()) {
+ if (!histogram.containsKey(s)) {
+ histogram.put(s, 1);
+ } else {
+ histogram.put(s, 1 + histogram.get(s));
+ }
+ }
+ }
+
+ // The test pins the exact counts for these seeds; they roughly follow the 2:4:1:1 weights
+ // (ideal proportions would be 25, 50, 12.5 and 12.5 out of 100).
+ Map<String, Integer> expected = ImmutableMap.of(
+ "foo", 24, "bar", 51, "baz", 13, "biz", 12);
+ assertEquals(expected, histogram);
+ }
+
+ @Test
+ public void testSample() {
PCollection<Integer> pcollect = MemPipeline.collectionOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+ // Fixed seed and p = 0.2: the test expects exactly 6 and 7 to be kept. The seed needs the
+ // L suffix because the parameter is now a boxed Long, and an int literal boxes to Integer.
- Iterable<Integer> sample = Sample.sample(pcollect, 123998, 0.2).materialize();
+ Iterable<Integer> sample = Sample.sample(pcollect, 123998L, 0.2).materialize();
List<Integer> sampleValues = ImmutableList.copyOf(sample);
assertEquals(ImmutableList.of(6, 7), sampleValues);
}