You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this text version.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/05 19:30:22 UTC

git commit: [branch-1.1][SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample

Repository: spark
Updated Branches:
  refs/heads/branch-1.1 1b282cdfd -> 44751af9f


[branch-1.1][SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample

Port of pull request #3010 (the SPARK-4148 fix) to branch-1.1.

Author: Xiangrui Meng <me...@databricks.com>

Closes #3104 from mengxr/SPARK-4148-1.1 and squashes the following commits:

684c002 [Xiangrui Meng] apply SPARK-4148 to branch-1.1


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/44751af9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/44751af9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/44751af9

Branch: refs/heads/branch-1.1
Commit: 44751af9f8ec6a2b6ca49e5aee3e924c61afd3f7
Parents: 1b282cd
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Nov 5 10:30:10 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Wed Nov 5 10:30:10 2014 -0800

----------------------------------------------------------------------
 python/pyspark/rdd.py        |  3 ---
 python/pyspark/rddsampler.py | 11 +++++------
 python/pyspark/tests.py      | 15 +++++++++++++++
 3 files changed, 20 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/44751af9/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2b47b6c..3f81550 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -417,9 +417,6 @@ class RDD(object):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
-
-        >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
-        [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
         """
         assert fraction >= 0.0, "Negative fraction value: %s" % fraction
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

http://git-wip-us.apache.org/repos/asf/spark/blob/44751af9/python/pyspark/rddsampler.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 55e247d..a6e8106 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -40,14 +40,13 @@ class RDDSamplerBase(object):
     def initRandomGenerator(self, split):
         if self._use_numpy:
             import numpy
-            self._random = numpy.random.RandomState(self._seed)
+            self._random = numpy.random.RandomState(self._seed ^ split)
         else:
-            self._random = random.Random(self._seed)
+            self._random = random.Random(self._seed ^ split)
 
-        for _ in range(0, split):
-            # discard the next few values in the sequence to have a
-            # different seed for the different splits
-            self._random.randint(0, sys.maxint)
+        # mixing because the initial seeds are close to each other
+        for _ in xrange(10):
+            self._random.randint(0, 1)
 
         self._split = split
         self._rand_initialized = True

http://git-wip-us.apache.org/repos/asf/spark/blob/44751af9/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8f0a351..5cea1b0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -470,6 +470,21 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
+    def test_sample(self):
+        rdd = self.sc.parallelize(range(0, 100), 4)
+        wo = rdd.sample(False, 0.1, 2).collect()
+        wo_dup = rdd.sample(False, 0.1, 2).collect()
+        self.assertSetEqual(set(wo), set(wo_dup))
+        wr = rdd.sample(True, 0.2, 5).collect()
+        wr_dup = rdd.sample(True, 0.2, 5).collect()
+        self.assertSetEqual(set(wr), set(wr_dup))
+        wo_s10 = rdd.sample(False, 0.3, 10).collect()
+        wo_s20 = rdd.sample(False, 0.3, 20).collect()
+        self.assertNotEqual(set(wo_s10), set(wo_s20))
+        wr_s11 = rdd.sample(True, 0.4, 11).collect()
+        wr_s21 = rdd.sample(True, 0.4, 21).collect()
+        self.assertNotEqual(set(wr_s11), set(wr_s21))
+
 
 class TestSQL(PySparkTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org