Posted to commits@spark.apache.org by ru...@apache.org on 2022/08/27 23:27:46 UTC
[spark] branch master updated: [SPARK-40240][PYTHON] PySpark rdd.takeSample should correctly validate `num > maxSampleSize`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f555a0d80e1 [SPARK-40240][PYTHON] PySpark rdd.takeSample should correctly validate `num > maxSampleSize`
f555a0d80e1 is described below
commit f555a0d80e1858ca30527328ca240b56ae6f415e
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Sun Aug 28 07:27:13 2022 +0800
[SPARK-40240][PYTHON] PySpark rdd.takeSample should correctly validate `num > maxSampleSize`
### What changes were proposed in this pull request?
Make PySpark's `rdd.takeSample` behave like the Scala side by validating the sample size up front.
### Why are the changes needed?
`rdd.takeSample` in Spark Core checks `num > maxsize - int(numStDev * sqrt(maxsize))` first, while PySpark performed this check only after the early-return path for `num >= initialCount`, so the validation could be skipped entirely:
```scala
scala> sc.range(0, 10).takeSample(false, Int.MaxValue)
java.lang.IllegalArgumentException: requirement failed: Cannot support a sample size > Int.MaxValue - 10.0 * math.sqrt(Int.MaxValue)
at scala.Predef$.require(Predef.scala:281)
at org.apache.spark.rdd.RDD.$anonfun$takeSample$1(RDD.scala:620)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
at org.apache.spark.rdd.RDD.takeSample(RDD.scala:615)
... 47 elided
```
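The bound the Scala side enforces can be sketched in plain Python (a minimal sketch; `validate_sample_size` is a hypothetical helper, not part of the PySpark API):

```python
import sys
from math import sqrt

def validate_sample_size(num, num_st_dev=10.0):
    # Hypothetical helper mirroring the Scala-side requirement:
    # num must stay roughly num_st_dev standard deviations below maxsize,
    # leaving headroom for the oversampling done internally.
    max_sample_size = sys.maxsize - int(num_st_dev * sqrt(sys.maxsize))
    if num < 0:
        raise ValueError("Sample size cannot be negative.")
    if num > max_sample_size:
        raise ValueError("Sample size cannot be greater than %d." % max_sample_size)

validate_sample_size(20)  # a reasonable size passes silently
```

PySpark, by contrast, returned a result without ever reaching this check: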
```python
In [2]: sc.range(0, 10).takeSample(False, sys.maxsize)
Out[2]: [9, 6, 8, 5, 7, 2, 0, 3, 4, 1]
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
added doctest
Closes #37683 from zhengruifeng/py_refine_takesample.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/rdd.py | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b631f141a89..5fe463233a2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1122,6 +1122,7 @@ class RDD(Generic[T_co]):
Examples
--------
+ >>> import sys
>>> rdd = sc.parallelize(range(0, 10))
>>> len(rdd.takeSample(True, 20, 1))
20
@@ -1129,12 +1130,19 @@ class RDD(Generic[T_co]):
5
>>> len(rdd.takeSample(False, 15, 3))
10
+ >>> sc.range(0, 10).takeSample(False, sys.maxsize)
+ Traceback (most recent call last):
+ ...
+ ValueError: Sample size cannot be greater than ...
"""
numStDev = 10.0
-
+ maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
if num < 0:
raise ValueError("Sample size cannot be negative.")
- elif num == 0:
+ elif num > maxSampleSize:
+ raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
+
+ if num == 0 or self.getNumPartitions() == 0:
return []
initialCount = self.count()
@@ -1149,10 +1157,6 @@ class RDD(Generic[T_co]):
rand.shuffle(samples)
return samples
- maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize))
- if num > maxSampleSize:
- raise ValueError("Sample size cannot be greater than %d." % maxSampleSize)
-
fraction = RDD._computeFractionForSampleSize(num, initialCount, withReplacement)
samples = self.sample(withReplacement, fraction, seed).collect()
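The crux of the diff above is ordering: the `maxSampleSize` check now runs before `self.count()`, so an oversized `num` fails fast instead of hitting the "return all elements" shortcut. A minimal sketch of the fixed control flow (hypothetical `take_sample_skeleton`, not the PySpark API; `count_fn` stands in for the distributed count):

```python
import sys
from math import sqrt

def take_sample_skeleton(num, count_fn):
    # Sketch of the fixed control flow: validate num *before* count_fn()
    # runs, so an invalid size raises instead of reaching the
    # "dataset smaller than num, return everything" shortcut.
    num_st_dev = 10.0
    max_sample_size = sys.maxsize - int(num_st_dev * sqrt(sys.maxsize))
    if num < 0:
        raise ValueError("Sample size cannot be negative.")
    elif num > max_sample_size:
        raise ValueError("Sample size cannot be greater than %d." % max_sample_size)
    if num == 0:
        return []
    initial_count = count_fn()  # the expensive count happens only now
    if initial_count <= num:
        return list(range(initial_count))  # stand-in for "all elements, shuffled"
    return None  # normal sampling path elided
```

With the old ordering, `take_sample_skeleton(sys.maxsize, ...)` would have returned the whole dataset; with the check first, it raises `ValueError` as on the Scala side.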