You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@crunch.apache.org by jw...@apache.org on 2013/03/13 17:46:32 UTC

git commit: CRUNCH-178: Add reservoir sampling functions to lib.Sample and make the reservoir and regular sampling APIs consistent.

Updated Branches:
  refs/heads/master 1138d3855 -> 96f5e9f8c


CRUNCH-178: Add reservoir sampling functions to lib.Sample and make the reservoir and
regular sampling APIs consistent.


Project: http://git-wip-us.apache.org/repos/asf/crunch/repo
Commit: http://git-wip-us.apache.org/repos/asf/crunch/commit/96f5e9f8
Tree: http://git-wip-us.apache.org/repos/asf/crunch/tree/96f5e9f8
Diff: http://git-wip-us.apache.org/repos/asf/crunch/diff/96f5e9f8

Branch: refs/heads/master
Commit: 96f5e9f8cbaf387c93204db2e9d15430154124cb
Parents: 1138d38
Author: Josh Wills <jw...@apache.org>
Authored: Wed Mar 6 15:09:28 2013 -0800
Committer: Josh Wills <jw...@apache.org>
Committed: Wed Mar 13 01:04:09 2013 -0700

----------------------------------------------------------------------
 .../main/java/org/apache/crunch/lib/Sample.java    |  197 ++++++++++++---
 .../java/org/apache/crunch/lib/SampleUtils.java    |  161 ++++++++++++
 .../java/org/apache/crunch/lib/SampleTest.java     |   37 +++-
 3 files changed, 356 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/Sample.java
----------------------------------------------------------------------
diff --git a/crunch/src/main/java/org/apache/crunch/lib/Sample.java b/crunch/src/main/java/org/apache/crunch/lib/Sample.java
index 5be2292..be75ae2 100644
--- a/crunch/src/main/java/org/apache/crunch/lib/Sample.java
+++ b/crunch/src/main/java/org/apache/crunch/lib/Sample.java
@@ -17,51 +17,37 @@
  */
 package org.apache.crunch.lib;
 
-import java.util.Random;
 
-import org.apache.crunch.FilterFn;
+import org.apache.crunch.MapFn;
 import org.apache.crunch.PCollection;
 import org.apache.crunch.PTable;
 import org.apache.crunch.Pair;
+import org.apache.crunch.lib.SampleUtils.ReservoirSampleFn;
+import org.apache.crunch.lib.SampleUtils.SampleFn;
+import org.apache.crunch.lib.SampleUtils.WRSCombineFn;
+import org.apache.crunch.types.PTableType;
+import org.apache.crunch.types.PType;
+import org.apache.crunch.types.PTypeFamily;
 
-import com.google.common.base.Preconditions;
-
+/**
+ * Methods for performing random sampling in a distributed fashion, either by accepting each
+ * record in a {@code PCollection} with an independent probability in order to sample some
+ * fraction of the overall data set, or by using reservoir sampling in order to pull a uniform
+ * or weighted sample of fixed size from a {@code PCollection} of an unknown size. For more details
+ * on the reservoir sampling algorithms used by this library, see the A-ES algorithm described in
+ * <a href="http://arxiv.org/pdf/1012.0256.pdf">Efraimidis (2012)</a>.
+ */
 public class Sample {
 
-  private static class SamplerFn<S> extends FilterFn<S> {
-
-    private final long seed;
-    private final double acceptanceProbability;
-    private transient Random r;
-
-    public SamplerFn(long seed, double acceptanceProbability) {
-      Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0);
-      this.seed = seed;
-      this.acceptanceProbability = acceptanceProbability;
-    }
-
-    @Override
-    public void initialize() {
-      if (r == null) {
-        r = new Random(seed);
-      }
-    }
-
-    @Override
-    public boolean accept(S input) {
-      return r.nextDouble() < acceptanceProbability;
-    }
-  }
-
   /**
    * Output records from the given {@code PCollection} with the given probability.
    * 
    * @param input The {@code PCollection} to sample from
-   * @param probability The probability (0.0 < p < 1.0)
+   * @param probability The probability (0.0 &lt; p &lt; 1.0)
    * @return The output {@code PCollection} created from sampling
+   * @throws IllegalArgumentException if {@code probability} is not strictly between 0.0 and 1.0
    */
   public static <S> PCollection<S> sample(PCollection<S> input, double probability) {
-    return sample(input, System.currentTimeMillis(), probability);
+    // A null seed makes the sampler seed its random number generator from the current time.
+    return sample(input, null, probability);
   }
 
   /**
@@ -69,26 +55,163 @@ public class Sample {
    * testing.
    * 
    * @param input The {@code PCollection} to sample from
-   * @param seed The seed
-   * @param probability The probability (0.0 < p < 1.0)
+   * @param seed The seed for the random number generator, or {@code null} to seed it from the
+   *     current time
+   * @param probability The probability (0.0 &lt; p &lt; 1.0)
    * @return The output {@code PCollection} created from sampling
+   * @throws IllegalArgumentException if {@code probability} is not strictly between 0.0 and 1.0
    */
-  public static <S> PCollection<S> sample(PCollection<S> input, long seed, double probability) {
+  public static <S> PCollection<S> sample(PCollection<S> input, Long seed, double probability) {
     String stageName = String.format("sample(%.2f)", probability);
-    return input.parallelDo(stageName, new SamplerFn<S>(seed, probability), input.getPType());
+    // SampleFn validates the probability and keeps each record based on its own random draw.
+    return input.parallelDo(stageName, new SampleFn<S>(probability, seed), input.getPType());
   }
   
   /**
    * A {@code PTable<K, V>} analogue of the {@code sample} function.
+   * 
+   * @param input The {@code PTable} to sample from
+   * @param probability The probability (0.0 &lt; p &lt; 1.0)
+   * @return The output {@code PTable} created from sampling
+   * @throws IllegalArgumentException if {@code probability} is not strictly between 0.0 and 1.0
    */
   public static <K, V> PTable<K, V> sample(PTable<K, V> input, double probability) {
+    // The upcast selects the PCollection overload; without it this call would resolve back to
+    // this PTable overload and recurse forever.
     return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, probability));
   }
   
   /**
-   * A {@code PTable<K, V>} analogue of the {@code sample} function.
+   * A {@code PTable<K, V>} analogue of the {@code sample} function, with the seed argument
+   * exposed for testing purposes.
+   * 
+   * @param input The {@code PTable} to sample from
+   * @param seed The seed for the random number generator, or {@code null} to seed it from the
+   *     current time
+   * @param probability The probability (0.0 &lt; p &lt; 1.0)
+   * @return The output {@code PTable} created from sampling
+   * @throws IllegalArgumentException if {@code probability} is not strictly between 0.0 and 1.0
    */
-  public static <K, V> PTable<K, V> sample(PTable<K, V> input, long seed, double probability) {
+  public static <K, V> PTable<K, V> sample(PTable<K, V> input, Long seed, double probability) {
+    // The upcast selects the PCollection overload; without it this call would resolve back to
+    // this PTable overload and recurse forever.
     return PTables.asPTable(sample((PCollection<Pair<K, V>>) input, seed, probability));
   }
+  
+  /**
+   * Select a fixed number of elements from the given {@code PCollection} with each element
+   * equally likely to be included in the sample.
+   * 
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @return A {@code PCollection} made up of the sampled elements
+   */
+  public static <T> PCollection<T> reservoirSample(
+      PCollection<T> input,
+      int sampleSize) {
+    return reservoirSample(input, sampleSize, null);
+  }
+
+  /**
+   * A version of the reservoir sampling algorithm that uses a given seed, primarily for
+   * testing purposes.
+   * 
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @param seed The test seed, or {@code null} to use an unseeded random number generator
+   * @return A {@code PCollection} made up of the sampled elements
+   */
+  public static <T> PCollection<T> reservoirSample(
+      PCollection<T> input,
+      int sampleSize,
+      Long seed) {
+    PTypeFamily ptf = input.getTypeFamily();
+    PType<Pair<T, Integer>> ptype = ptf.pairs(input.getPType(), ptf.ints());
+    // Giving every element the same weight turns the weighted sampler into a uniform one.
+    return weightedReservoirSample(
+        input.parallelDo(new MapFn<T, Pair<T, Integer>>() {
+          @Override
+          public Pair<T, Integer> map(T t) { return Pair.of(t, 1); }
+        }, ptype),
+        sampleSize,
+        seed);
+  }
+
+  /**
+   * Misspelled alias of {@link #reservoirSample(PCollection, int, Long)}, retained so that
+   * existing callers continue to compile.
+   *
+   * @param input The input data
+   * @param sampleSize The number of elements to select
+   * @param seed The test seed, or {@code null} to use an unseeded random number generator
+   * @return A {@code PCollection} made up of the sampled elements
+   * @deprecated Use {@link #reservoirSample(PCollection, int, Long)} instead.
+   */
+  @Deprecated
+  public static <T> PCollection<T> reservorSample(
+      PCollection<T> input,
+      int sampleSize,
+      Long seed) {
+    return reservoirSample(input, sampleSize, seed);
+  }
+  
+  /**
+   * Draws a weighted sample of at most {@code sampleSize} elements from {@code input}. In each
+   * {@code Pair}, the first term is the observation and the second term is its numerical
+   * weight; observations with larger weights are more likely to be selected.
+   * 
+   * @param input the weighted observations
+   * @param sampleSize The number of elements to select
+   * @return A random sample of the given size that respects the weighting values
+   */
+  public static <T, N extends Number> PCollection<T> weightedReservoirSample(
+      PCollection<Pair<T, N>> input,
+      int sampleSize) {
+    Long unseeded = null;
+    return weightedReservoirSample(input, sampleSize, unseeded);
+  }
+  
+  /**
+   * Seeded variant of {@link #weightedReservoirSample(PCollection, int)}, intended mainly for
+   * reproducible tests.
+   * 
+   * @param input the weighted observations
+   * @param sampleSize The number of elements to select
+   * @param seed The test seed, or {@code null} for an unseeded random number generator
+   * @return A random sample of the given size that respects the weighting values
+   */
+  public static <T, N extends Number> PCollection<T> weightedReservoirSample(
+      PCollection<Pair<T, N>> input,
+      int sampleSize,
+      Long seed) {
+    PTypeFamily typeFamily = input.getTypeFamily();
+    PTableType<Integer, Pair<T, N>> keyedType =
+        typeFamily.tableOf(typeFamily.ints(), input.getPType());
+    // Route every observation to group 0 so the grouped sampler does the actual work.
+    PTable<Integer, Pair<T, N>> singleGroup = input.parallelDo(
+        new MapFn<Pair<T, N>, Pair<Integer, Pair<T, N>>>() {
+          @Override
+          public Pair<Integer, Pair<T, N>> map(Pair<T, N> observation) {
+            return Pair.of(0, observation);
+          }
+        }, keyedType);
+    PType<T> elementType = (PType<T>) input.getPType().getSubTypes().get(0);
+    PCollection<Pair<Integer, T>> keyedSample =
+        groupedWeightedReservoirSample(singleGroup, new int[] { sampleSize }, seed);
+    // Strip the constant group id back off each sampled element.
+    return keyedSample.parallelDo(new MapFn<Pair<Integer, T>, T>() {
+      @Override
+      public T map(Pair<Integer, T> keyed) {
+        return keyed.second();
+      }
+    }, elementType);
+  }
+  
+  /**
+   * Draws an independent weighted reservoir sample for each of N groups. The key of each input
+   * record is its group ID, which must lie in {@code [0, sampleSizes.length)}, and the value is
+   * an observation paired with its numerical weight.
+   * 
+   * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group
+   * @param sampleSizes An array of length N where entry {@code i} is the number of elements to
+   *     include from group {@code i}
+   * @return A {@code PCollection} of (group ID, sampled element) pairs
+   */
+  public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(
+      PTable<Integer, Pair<T, N>> input,
+      int[] sampleSizes) {
+    Long unseeded = null;
+    return groupedWeightedReservoirSample(input, sampleSizes, unseeded);
+  }
+  
+  /**
+   * Seeded variant of {@link #groupedWeightedReservoirSample(PTable, int[])}, intended mainly
+   * for reproducible tests.
+   * 
+   * @param input A {@code PTable} with the key a group ID and the value a weighted observation in that group
+   * @param sampleSizes An array of length N where entry {@code i} is the number of elements to
+   *     include from group {@code i}
+   * @param seed The test seed, or {@code null} for an unseeded random number generator
+   * @return A {@code PCollection} of (group ID, sampled element) pairs
+   */
+  public static <T, N extends Number> PCollection<Pair<Integer, T>> groupedWeightedReservoirSample(
+      PTable<Integer, Pair<T, N>> input,
+      int[] sampleSizes,
+      Long seed) {
+    PTypeFamily typeFamily = input.getTypeFamily();
+    PType<T> elementType = (PType<T>) input.getPTableType().getValueType().getSubTypes().get(0);
+    PTableType<Integer, Pair<Double, T>> scoredType = typeFamily.tableOf(
+        typeFamily.ints(), typeFamily.pairs(typeFamily.doubles(), elementType));
+
+    // Map side: each task keeps its highest-scoring candidates per group.
+    PTable<Integer, Pair<Double, T>> candidates =
+        input.parallelDo(new ReservoirSampleFn<T, N>(sampleSizes, seed), scoredType);
+    // Reduce side: merge the per-task candidates into a single reservoir per group.
+    PTable<Integer, Pair<Double, T>> merged = candidates
+        .groupByKey(1)
+        .combineValues(new WRSCombineFn<T>(sampleSizes));
+    // Drop the scores, keeping (group ID, element).
+    return merged.parallelDo(new MapFn<Pair<Integer, Pair<Double, T>>, Pair<Integer, T>>() {
+      @Override
+      public Pair<Integer, T> map(Pair<Integer, Pair<Double, T>> scored) {
+        return Pair.of(scored.first(), scored.second().second());
+      }
+    }, typeFamily.pairs(typeFamily.ints(), elementType));
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
----------------------------------------------------------------------
diff --git a/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
new file mode 100644
index 0000000..cbc30e4
--- /dev/null
+++ b/crunch/src/main/java/org/apache/crunch/lib/SampleUtils.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.crunch.lib;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.SortedMap;
+
+import org.apache.crunch.CombineFn;
+import org.apache.crunch.DoFn;
+import org.apache.crunch.Emitter;
+import org.apache.crunch.FilterFn;
+import org.apache.crunch.Pair;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+/**
+ * Package-private {@code DoFn} implementations backing the public sampling methods in
+ * {@link Sample}.
+ */
+class SampleUtils {
+
+  private SampleUtils() {
+    // Holder for static nested classes and helpers only; never instantiated.
+  }
+
+  /**
+   * Keeps each input record independently with a fixed acceptance probability.
+   */
+  static class SampleFn<S> extends FilterFn<S> {
+
+    private final Long seed;
+    private final double acceptanceProbability;
+    private transient Random r;
+
+    /**
+     * @param acceptanceProbability The probability of keeping each record; must be strictly
+     *     between 0.0 and 1.0
+     * @param seed The random seed, or {@code null} to seed from the current time
+     * @throws IllegalArgumentException if {@code acceptanceProbability} is out of range
+     */
+    public SampleFn(double acceptanceProbability, Long seed) {
+      Preconditions.checkArgument(0.0 < acceptanceProbability && acceptanceProbability < 1.0,
+          "acceptanceProbability must be strictly between 0.0 and 1.0, got %s",
+          acceptanceProbability);
+      // A null seed is resolved once, at construction time.
+      this.seed = seed == null ? System.currentTimeMillis() : seed;
+      this.acceptanceProbability = acceptanceProbability;
+    }
+
+    @Override
+    public void initialize() {
+      if (r == null) {
+        r = new Random(seed);
+      }
+    }
+
+    @Override
+    public boolean accept(S input) {
+      return r.nextDouble() < acceptanceProbability;
+    }
+  }
+
+  /**
+   * Map-side half of weighted reservoir sampling (A-ES): scores each (group ID, (element,
+   * weight)) record and keeps the highest-scoring elements of every group, emitting them as
+   * (group ID, (score, element)) from {@link #cleanup}.
+   */
+  static class ReservoirSampleFn<T, N extends Number>
+      extends DoFn<Pair<Integer, Pair<T, N>>, Pair<Integer, Pair<Double, T>>> {
+
+    private final int[] sampleSizes;
+    private final Long seed;
+    private transient List<SortedMap<Double, T>> reservoirs;
+    private transient Random random;
+
+    /**
+     * @param sampleSizes The maximum number of elements to keep for each group ID
+     * @param seed The random seed, or {@code null} for an unseeded generator
+     * @throws NullPointerException if {@code sampleSizes} is null
+     * @throws IllegalArgumentException if any sample size is negative
+     */
+    public ReservoirSampleFn(int[] sampleSizes, Long seed) {
+      this.sampleSizes = checkSampleSizes(sampleSizes);
+      this.seed = seed;
+    }
+
+    @Override
+    public void initialize() {
+      this.reservoirs = newReservoirs(sampleSizes.length);
+      if (random == null) {
+        if (seed == null) {
+          this.random = new Random();
+        } else {
+          this.random = new Random(seed);
+        }
+      }
+    }
+
+    @Override
+    public void process(Pair<Integer, Pair<T, N>> input,
+        Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      int id = input.first();
+      Pair<T, N> p = input.second();
+      double weight = p.second().doubleValue();
+      // Elements with non-positive weight are never sampled (and consume no random draw).
+      if (weight > 0.0) {
+        // A-ES key u^(1/w), kept in log space as log(u) / w; log is monotone so the ranking of
+        // keys is unchanged.
+        double score = Math.log(random.nextDouble()) / weight;
+        offer(reservoirs.get(id), sampleSizes[id], score, p.first());
+      }
+    }
+
+    @Override
+    public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      emitReservoirs(reservoirs, emitter);
+    }
+  }
+
+  /**
+   * Reduce-side half of weighted reservoir sampling: merges scored candidates for each group ID
+   * into a single reservoir of the configured size. Results are emitted from {@link #cleanup}.
+   */
+  static class WRSCombineFn<T> extends CombineFn<Integer, Pair<Double, T>> {
+
+    private final int[] sampleSizes;
+    // Per-task working state, rebuilt in initialize(); never meant to be serialized.
+    private transient List<SortedMap<Double, T>> reservoirs;
+
+    /**
+     * @param sampleSizes The maximum number of elements to keep for each group ID
+     * @throws NullPointerException if {@code sampleSizes} is null
+     * @throws IllegalArgumentException if any sample size is negative
+     */
+    public WRSCombineFn(int[] sampleSizes) {
+      this.sampleSizes = checkSampleSizes(sampleSizes);
+    }
+
+    @Override
+    public void initialize() {
+      this.reservoirs = newReservoirs(sampleSizes.length);
+    }
+
+    @Override
+    public void process(Pair<Integer, Iterable<Pair<Double, T>>> input,
+        Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      int id = input.first();
+      SortedMap<Double, T> reservoir = reservoirs.get(id);
+      for (Pair<Double, T> p : input.second()) {
+        offer(reservoir, sampleSizes[id], p.first(), p.second());
+      }
+    }
+
+    @Override
+    public void cleanup(Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+      emitReservoirs(reservoirs, emitter);
+    }
+  }
+
+  /**
+   * Validates per-group sample sizes and returns a private copy, so later changes to the
+   * caller's array cannot affect the function.
+   */
+  private static int[] checkSampleSizes(int[] sampleSizes) {
+    Preconditions.checkNotNull(sampleSizes, "sampleSizes");
+    for (int size : sampleSizes) {
+      Preconditions.checkArgument(size >= 0, "sample sizes must be non-negative, got %s", size);
+    }
+    return sampleSizes.clone();
+  }
+
+  /** Creates one empty score-ordered reservoir per group. */
+  private static <T> List<SortedMap<Double, T>> newReservoirs(int count) {
+    List<SortedMap<Double, T>> reservoirs = Lists.newArrayList();
+    for (int i = 0; i < count; i++) {
+      reservoirs.add(Maps.<Double, T>newTreeMap());
+    }
+    return reservoirs;
+  }
+
+  /**
+   * Offers a scored element to a reservoir holding at most {@code capacity} entries, evicting
+   * the lowest score when the reservoir is full and the new score is higher. The emptiness check
+   * keeps a zero-capacity reservoir from calling {@code firstKey()} on an empty map.
+   * NOTE: entries are keyed by score, so an exact score collision replaces the earlier element.
+   */
+  private static <T> void offer(SortedMap<Double, T> reservoir, int capacity, double score,
+      T value) {
+    if (reservoir.size() < capacity) {
+      reservoir.put(score, value);
+    } else if (!reservoir.isEmpty() && score > reservoir.firstKey()) {
+      reservoir.remove(reservoir.firstKey());
+      reservoir.put(score, value);
+    }
+  }
+
+  /** Emits every reservoir entry as (group ID, (score, element)). */
+  private static <T> void emitReservoirs(List<SortedMap<Double, T>> reservoirs,
+      Emitter<Pair<Integer, Pair<Double, T>>> emitter) {
+    for (int id = 0; id < reservoirs.size(); id++) {
+      for (Map.Entry<Double, T> e : reservoirs.get(id).entrySet()) {
+        emitter.emit(Pair.of(id, Pair.of(e.getKey(), e.getValue())));
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/crunch/blob/96f5e9f8/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
----------------------------------------------------------------------
diff --git a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
index 69fd074..bd6fd81 100644
--- a/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
+++ b/crunch/src/test/java/org/apache/crunch/lib/SampleTest.java
@@ -20,18 +20,51 @@ package org.apache.crunch.lib;
 import static org.junit.Assert.assertEquals;
 
 import java.util.List;
+import java.util.Map;
 
 import org.apache.crunch.PCollection;
+import org.apache.crunch.Pair;
 import org.apache.crunch.impl.mem.MemPipeline;
+import org.apache.crunch.types.writable.Writables;
 import org.junit.Test;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Maps;
 
 public class SampleTest {
+  private PCollection<Pair<String, Double>> values = MemPipeline.typedCollectionOf(
+      Writables.pairs(Writables.strings(), Writables.doubles()),
+      ImmutableList.of(
+        Pair.of("foo", 200.0),
+        Pair.of("bar", 400.0),
+        Pair.of("baz", 100.0),
+        Pair.of("biz", 100.0)));
+  
   @Test
-  public void testSampler() {
+  public void testWRS() throws Exception {
+    Map<String, Integer> histogram = Maps.newHashMap();
+    
+    for (int i = 0; i < 100; i++) {
+      PCollection<String> sample = Sample.weightedReservoirSample(values, 1, 1729L + i);
+      for (String s : sample.materialize()) {
+        if (!histogram.containsKey(s)) {
+          histogram.put(s, 1);
+        } else {
+          histogram.put(s, 1 + histogram.get(s));
+        }
+      }
+    }
+    
+    Map<String, Integer> expected = ImmutableMap.of(
+        "foo", 24, "bar", 51, "baz", 13, "biz", 12);
+    assertEquals(expected, histogram);
+  }
+
+  @Test
+  public void testSample() {
     PCollection<Integer> pcollect = MemPipeline.collectionOf(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
-    Iterable<Integer> sample = Sample.sample(pcollect, 123998, 0.2).materialize();
+    // A fixed seed makes the sampler's Random draws reproducible, so the exact survivors are
+    // known. The L suffix is required: an int literal cannot be boxed to the Long seed parameter.
+    Iterable<Integer> sample = Sample.sample(pcollect, 123998L, 0.2).materialize();
     List<Integer> sampleValues = ImmutableList.copyOf(sample);
     assertEquals(ImmutableList.of(6, 7), sampleValues);
   }