You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2014/12/19 04:58:04 UTC
spark git commit: [branch-1.0][SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
Repository: spark
Updated Branches:
refs/heads/branch-1.0 d2f86331d -> e0fc0c56f
[branch-1.0][SPARK-4148][PySpark] fix seed distribution and add some tests for rdd.sample
Port #3010 to branch-1.0.
Author: Xiangrui Meng <me...@databricks.com>
Closes #3106 from mengxr/SPARK-4148-1.0 and squashes the following commits:
c834cee [Xiangrui Meng] apply SPARK-4148 to branch-1.0
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0fc0c56
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0fc0c56
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0fc0c56
Branch: refs/heads/branch-1.0
Commit: e0fc0c56fa86e863f6ddf1f67a158d8b714e75c3
Parents: d2f8633
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu Dec 18 19:57:36 2014 -0800
Committer: Josh Rosen <jo...@databricks.com>
Committed: Thu Dec 18 19:57:36 2014 -0800
----------------------------------------------------------------------
python/pyspark/rdd.py | 3 ---
python/pyspark/rddsampler.py | 11 +++++------
python/pyspark/tests.py | 15 +++++++++++++++
3 files changed, 20 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e0fc0c56/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 57c2cd7..1c3b5e4 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -366,9 +366,6 @@ class RDD(object):
"""
Return a sampled subset of this RDD (relies on numpy and falls back
on default random generator if numpy is unavailable).
-
- >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
- [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
http://git-wip-us.apache.org/repos/asf/spark/blob/e0fc0c56/python/pyspark/rddsampler.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 845a267..fac2008 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -37,14 +37,13 @@ class RDDSampler(object):
def initRandomGenerator(self, split):
if self._use_numpy:
import numpy
- self._random = numpy.random.RandomState(self._seed)
+ self._random = numpy.random.RandomState(self._seed ^ split)
else:
- self._random = random.Random(self._seed)
+ self._random = random.Random(self._seed ^ split)
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- self._random.randint(0, sys.maxint)
+ # mixing because the initial seeds are close to each other
+ for _ in xrange(10):
+ self._random.randint(0, 1)
self._split = split
self._rand_initialized = True
http://git-wip-us.apache.org/repos/asf/spark/blob/e0fc0c56/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6ba76ec..0605893 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -216,6 +216,21 @@ class TestRDDFunctions(PySparkTestCase):
self.assertEqual([1], rdd.map(itemgetter(1)).collect())
self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
class TestIO(PySparkTestCase):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org