Posted to commits@spark.apache.org by me...@apache.org on 2014/07/25 08:42:42 UTC

git commit: [SPARK-2656] Python version of stratified sampling

Repository: spark
Updated Branches:
  refs/heads/master 14174abd4 -> 2f75a4a30


[SPARK-2656] Python version of stratified sampling

Exact sample size is not supported for now.
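
The new sampleByKey API samples each key's stratum at the rate given in a
key-to-fraction map, so per-key sample sizes are random with expected value
fraction * stratum size rather than exact. A minimal usage sketch, assuming an
active SparkContext sc (it mirrors the doctest in the diff below):

    fractions = {"a": 0.2, "b": 0.1}   # per-key sampling rates
    pairs = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(1000)))
    sample = pairs.sampleByKey(False, fractions, seed=2)
    # Each "a" pair is kept with probability 0.2 and each "b" pair with 0.1,
    # so the expected stratum sizes are 200 and 100; actual sizes vary.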

Author: Doris Xin <do...@gmail.com>

Closes #1554 from dorx/pystratified and squashes the following commits:

4ba927a [Doris Xin] use relative diff (+/- 50%) instead of absolute diff (+/- 50)
bdc3f8b [Doris Xin] updated unit test to check sample holistically
7713c7b [Doris Xin] Python version of stratified sampling


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f75a4a3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f75a4a3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f75a4a3

Branch: refs/heads/master
Commit: 2f75a4a30e1a3fdf384475b9660c6c43f093f68c
Parents: 14174ab
Author: Doris Xin <do...@gmail.com>
Authored: Thu Jul 24 23:42:08 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu Jul 24 23:42:08 2014 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  2 +-
 python/pyspark/rdd.py                           | 25 ++++++++++++++--
 python/pyspark/rddsampler.py                    | 30 ++++++++++++++++++--
 3 files changed, 51 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f75a4a3/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index c1bafab..edbf7ea 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -354,7 +354,7 @@ abstract class RDD[T: ClassTag](
   def sample(withReplacement: Boolean, 
       fraction: Double, 
       seed: Long = Utils.random.nextLong): RDD[T] = {
-    require(fraction >= 0.0, "Invalid fraction value: " + fraction)
+    require(fraction >= 0.0, "Negative fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/2f75a4a3/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ad6108..113a082 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler
+from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -411,7 +411,7 @@ class RDD(object):
         >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
         [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
-        assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
@@ -1456,6 +1456,27 @@ class RDD(object):
         """
         return python_cogroup((self, other), numPartitions)
 
+    def sampleByKey(self, withReplacement, fractions, seed=None):
+        """
+        Return a subset of this RDD sampled by key (via stratified sampling).
+        Create a sample of this RDD using variable sampling rates for
+        different keys as specified by fractions, a key to sampling rate map.
+
+        >>> fractions = {"a": 0.2, "b": 0.1}
+        >>> rdd = sc.parallelize(fractions.keys()).cartesian(sc.parallelize(range(0, 1000)))
+        >>> sample = dict(rdd.sampleByKey(False, fractions, 2).groupByKey().collect())
+        >>> 100 < len(sample["a"]) < 300 and 50 < len(sample["b"]) < 150
+        True
+        >>> max(sample["a"]) <= 999 and min(sample["a"]) >= 0
+        True
+        >>> max(sample["b"]) <= 999 and min(sample["b"]) >= 0
+        True
+        """
+        for fraction in fractions.values():
+            assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        return self.mapPartitionsWithIndex( \
+            RDDStratifiedSampler(withReplacement, fractions, seed).func, True)
+
     def subtractByKey(self, other, numPartitions=None):
         """
         Return each (key, value) pair in C{self} that has no pair with matching

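Note: as the "use relative diff" squashed commit above indicates, the doctest
bounds are relative rather than absolute. With 1000 values per key, the
expected stratum sizes are 1000 * 0.2 = 200 for "a" and 1000 * 0.1 = 100 for
"b", and the asserted ranges (100, 300) and (50, 150) are exactly those
expectations +/- 50%.
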
http://git-wip-us.apache.org/repos/asf/spark/blob/2f75a4a3/python/pyspark/rddsampler.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 7ff1c31..2df000f 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,8 +19,8 @@ import sys
 import random
 
 
-class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed=None):
+class RDDSamplerBase(object):
+    def __init__(self, withReplacement, seed=None):
         try:
             import numpy
             self._use_numpy = True
@@ -32,7 +32,6 @@ class RDDSampler(object):
 
         self._seed = seed if seed is not None else random.randint(0, sys.maxint)
         self._withReplacement = withReplacement
-        self._fraction = fraction
         self._random = None
         self._split = None
         self._rand_initialized = False
@@ -94,6 +93,12 @@ class RDDSampler(object):
         else:
             self._random.shuffle(vals, self._random.random)
 
+
+class RDDSampler(RDDSamplerBase):
+    def __init__(self, withReplacement, fraction, seed=None):
+        RDDSamplerBase.__init__(self, withReplacement, seed)
+        self._fraction = fraction
+
     def func(self, split, iterator):
         if self._withReplacement:
             for obj in iterator:
@@ -107,3 +112,22 @@ class RDDSampler(object):
             for obj in iterator:
                 if self.getUniformSample(split) <= self._fraction:
                     yield obj
+
+class RDDStratifiedSampler(RDDSamplerBase):
+    def __init__(self, withReplacement, fractions, seed=None):
+        RDDSamplerBase.__init__(self, withReplacement, seed)
+        self._fractions = fractions
+
+    def func(self, split, iterator):
+        if self._withReplacement:
+            for key, val in iterator:
+                # For large datasets, the expected number of occurrences of each element in
+                # a sample with replacement is Poisson(frac). We use that to get a count for
+                # each element.
+                count = self.getPoissonSample(split, mean=self._fractions[key])
+                for _ in range(0, count):
+                    yield key, val
+        else:
+            for key, val in iterator:
+                if self.getUniformSample(split) <= self._fractions[key]:
+                    yield key, val
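
For context, here is a minimal standalone sketch of the scheme that
RDDStratifiedSampler.func implements (plain Python, no Spark; poisson_draw and
sample_by_key are illustrative names, not PySpark API). With replacement, each
pair is emitted a Poisson(fractions[key]) number of times, matching the
comment in the diff; without replacement, each pair is kept with probability
fractions[key].

    import math
    import random

    def poisson_draw(rng, mean):
        # Knuth's inversion method; adequate for the small means used
        # as sampling fractions.
        threshold = math.exp(-mean)
        k, p = 0, 1.0
        while True:
            p *= rng.random()
            if p <= threshold:
                return k
            k += 1

    def sample_by_key(pairs, fractions, with_replacement, seed=42):
        rng = random.Random(seed)
        for key, val in pairs:
            if with_replacement:
                # Emit each pair Poisson(fractions[key]) times: the expected
                # number of copies equals the sampling fraction.
                for _ in range(poisson_draw(rng, fractions[key])):
                    yield key, val
            elif rng.random() <= fractions[key]:
                # Keep each pair with probability fractions[key].
                yield key, val

    data = [("a", i) for i in range(1000)] + [("b", i) for i in range(1000)]
    sample = list(sample_by_key(data, {"a": 0.2, "b": 0.1}, True))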