Posted to commits@spark.apache.org by do...@apache.org on 2022/06/05 23:51:00 UTC
[spark] branch master updated: [SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new bab70b1ef24 [SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py
bab70b1ef24 is described below
commit bab70b1ef24a2461395b32f609a9274269cb000e
Author: pralabhkumar <pr...@gmail.com>
AuthorDate: Sun Jun 5 16:50:29 2022 -0700
[SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py
### What changes were proposed in this pull request?
This PR adds test cases for shuffle.py
### Why are the changes needed?
To cover corner cases and increase test coverage. This brings the coverage of shuffle.py to close to 90%
### Does this PR introduce _any_ user-facing change?
No - test only
### How was this patch tested?
CI in this PR should test it out
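For local verification, a minimal sketch along these lines should also run the new tests with a plain unittest runner (assuming a Spark checkout with python/ on PYTHONPATH; the module path matches the file touched below):

import unittest

# Discover and run every test in the shuffle test module.
suite = unittest.defaultTestLoader.loadTestsFromName("pyspark.tests.test_shuffle")
unittest.TextTestRunner(verbosity=2).run(suite)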
Closes #36701 from pralabhkumar/rk_test_taskcontext.
Authored-by: pralabhkumar <pr...@gmail.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
python/pyspark/tests/test_shuffle.py | 94 +++++++++++++++++++++++++++++++++++-
1 file changed, 93 insertions(+), 1 deletion(-)
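For context, the tests below lean on the Aggregator contract from pyspark/shuffle.py: createCombiner turns a single value into a combiner, mergeValue folds one more value into a combiner, and mergeCombiners merges two combiners. A minimal sketch of both aggregator flavors used in the patch (the sum combiner is illustrative, not taken from the patch):

from pyspark.shuffle import Aggregator, SimpleAggregator

# Full three-function form: collect values into a list per key.
list_agg = Aggregator(
    lambda v: [v],                       # createCombiner: value -> combiner
    lambda c, v: c.append(v) or c,       # mergeValue: fold a value in
    lambda c1, c2: c1.extend(c2) or c1,  # mergeCombiners: merge two combiners
)

# SimpleAggregator: a single commutative, associative function plays both
# merge roles, here summing the values per key.
sum_agg = SimpleAggregator(lambda x, y: x + y)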
diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py
index cea29c79357..fb11a84f8af 100644
--- a/python/pyspark/tests/test_shuffle.py
+++ b/python/pyspark/tests/test_shuffle.py
@@ -16,11 +16,20 @@
#
import random
import unittest
+import tempfile
+import os
from py4j.protocol import Py4JJavaError
from pyspark import shuffle, CPickleSerializer, SparkConf, SparkContext
-from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
+from pyspark.shuffle import (
+ Aggregator,
+ ExternalMerger,
+ ExternalSorter,
+ SimpleAggregator,
+ Merger,
+ ExternalGroupBy,
+)
class MergerTests(unittest.TestCase):
@@ -54,6 +63,57 @@ class MergerTests(unittest.TestCase):
self.assertTrue(m.spills >= 1)
self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)) * 3)
+ def test_shuffle_data_with_multiple_locations(self):
+ # SPARK-39179: Test shuffling data across multiple locations; also check
+ # that the shuffle locations are randomized
+
+ with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
+ original = os.environ.get("SPARK_LOCAL_DIRS", None)
+ os.environ["SPARK_LOCAL_DIRS"] = tempdir1 + "," + tempdir2
+ try:
+ index_of_tempdir1 = [False, False]
+ for idx in range(10):
+ m = ExternalMerger(self.agg, 20)
+ if m.localdirs[0].startswith(tempdir1):
+ index_of_tempdir1[0] = True
+ elif m.localdirs[1].startswith(tempdir1):
+ index_of_tempdir1[1] = True
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)))
+ self.assertTrue(
+ index_of_tempdir1[0] and (index_of_tempdir1[0] == index_of_tempdir1[1])
+ )
+ finally:
+ if original is not None:
+ os.environ["SPARK_LOCAL_DIRS"] = original
+ else:
+ del os.environ["SPARK_LOCAL_DIRS"]
+
+ def test_simple_aggregator_with_medium_dataset(self):
+ # SPARK-39179: Test SimpleAggregator on a medium dataset
+ agg = SimpleAggregator(lambda x, y: x + y)
+ m = ExternalMerger(agg, 20)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(v for k, v in m.items()), sum(range(self.N)))
+
+ def test_merger_not_implemented_error(self):
+ # SPARK-39179: Test that the abstract Merger raises NotImplementedError
+ agg = SimpleAggregator(lambda x, y: x + y)
+
+ class DummyMerger(Merger):
+ def __init__(self, agg):
+ Merger.__init__(self, agg)
+
+ dummy_merger = DummyMerger(agg)
+ with self.assertRaises(NotImplementedError):
+ dummy_merger.mergeValues(self.data)
+ with self.assertRaises(NotImplementedError):
+ dummy_merger.mergeCombiners(self.data)
+ with self.assertRaises(NotImplementedError):
+ dummy_merger.items()
+
def test_huge_dataset(self):
m = ExternalMerger(self.agg, 5, partitions=3)
m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
@@ -117,6 +177,38 @@ class MergerTests(unittest.TestCase):
m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
+class ExternalGroupByTests(unittest.TestCase):
+ def setUp(self):
+ self.N = 1 << 20
+ values = [i for i in range(self.N)]
+ keys = [i for i in range(2)]
+ import itertools
+
+ self.data = [value for value in itertools.product(keys, values)]
+ self.agg = Aggregator(
+ lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x
+ )
+
+ def test_medium_dataset(self):
+ # SPARK-39179: Test ExternalGroupBy on a medium dataset
+ m = ExternalGroupBy(self.agg, 5, partitions=3)
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))
+
+ def test_dataset_with_keys_are_unsorted(self):
+ # SPARK-39179: Test ExternalGroupBy when the number of keys exceeds SORT_KEY_LIMIT.
+ m = ExternalGroupBy(self.agg, 5, partitions=3)
+ original = m.SORT_KEY_LIMIT
+ try:
+ m.SORT_KEY_LIMIT = 1
+ m.mergeValues(self.data)
+ self.assertTrue(m.spills >= 1)
+ self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))
+ finally:
+ m.SORT_KEY_LIMIT = original
+
+
class SorterTests(unittest.TestCase):
def test_in_memory_sort(self):
lst = list(range(1024))
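As a standalone illustration of what the new tests exercise, a sketch like the following forces ExternalMerger to spill by shrinking its memory limit (the second constructor argument, in MB, matching the 20 the tests pass; the data sizes here are illustrative):

from pyspark.shuffle import Aggregator, ExternalMerger

N = 1 << 16
data = [(i % 64, i) for i in range(N)]
agg = Aggregator(
    lambda v: [v],
    lambda c, v: c.append(v) or c,
    lambda c1, c2: c1.extend(c2) or c1,
)

# A tiny memory limit (1 MB) should force at least one spill to local disk.
m = ExternalMerger(agg, 1)
m.mergeValues(data)
print("spills:", m.spills)
assert sum(sum(v) for _, v in m.items()) == sum(range(N))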