Posted to commits@spark.apache.org by do...@apache.org on 2022/06/05 23:51:00 UTC

[spark] branch master updated: [SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bab70b1ef24 [SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py
bab70b1ef24 is described below

commit bab70b1ef24a2461395b32f609a9274269cb000e
Author: pralabhkumar <pr...@gmail.com>
AuthorDate: Sun Jun 5 16:50:29 2022 -0700

    [SPARK-39179][PYTHON][TESTS] Improve the test coverage for pyspark/shuffle.py
    
    ### What changes were proposed in this pull request?
    This PR adds test cases for shuffle.py.
    
    ### Why are the changes needed?
    To cover corner cases and increase test coverage. This will bring the coverage of shuffle.py close to 90%.
    
    ### Does this PR introduce _any_ user-facing change?
    No - test only
    
    ### How was this patch tested?
    CI in this PR should test it out
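
    For reference, a minimal, hypothetical way to exercise only this suite locally with
    the standard library unittest runner (assuming a PySpark dev checkout is importable
    on PYTHONPATH; nothing below is part of the patch itself):

        import unittest

        # Load and run just the shuffle tests touched by this change.
        suite = unittest.defaultTestLoader.loadTestsFromNames(["pyspark.tests.test_shuffle"])
        unittest.TextTestRunner(verbosity=2).run(suite)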
    
    Closes #36701 from pralabhkumar/rk_test_taskcontext.
    
    Authored-by: pralabhkumar <pr...@gmail.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 python/pyspark/tests/test_shuffle.py | 94 +++++++++++++++++++++++++++++++++++-
 1 file changed, 93 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py
index cea29c79357..fb11a84f8af 100644
--- a/python/pyspark/tests/test_shuffle.py
+++ b/python/pyspark/tests/test_shuffle.py
@@ -16,11 +16,20 @@
 #
 import random
 import unittest
+import tempfile
+import os
 
 from py4j.protocol import Py4JJavaError
 
 from pyspark import shuffle, CPickleSerializer, SparkConf, SparkContext
-from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
+from pyspark.shuffle import (
+    Aggregator,
+    ExternalMerger,
+    ExternalSorter,
+    SimpleAggregator,
+    Merger,
+    ExternalGroupBy,
+)
 
 
 class MergerTests(unittest.TestCase):
@@ -54,6 +63,57 @@ class MergerTests(unittest.TestCase):
         self.assertTrue(m.spills >= 1)
         self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)) * 3)
 
+    def test_shuffle_data_with_multiple_locations(self):
+        # SPARK-39179: Test shuffling data across multiple spill locations, and also
+        # check that the shuffle locations get randomized
+
+        with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
+            original = os.environ.get("SPARK_LOCAL_DIRS", None)
+            os.environ["SPARK_LOCAL_DIRS"] = tempdir1 + "," + tempdir2
+            try:
+                index_of_tempdir1 = [False, False]
+                for idx in range(10):
+                    m = ExternalMerger(self.agg, 20)
+                    if m.localdirs[0].startswith(tempdir1):
+                        index_of_tempdir1[0] = True
+                    elif m.localdirs[1].startswith(tempdir1):
+                        index_of_tempdir1[1] = True
+                    m.mergeValues(self.data)
+                    self.assertTrue(m.spills >= 1)
+                    self.assertEqual(sum(sum(v) for k, v in m.items()), sum(range(self.N)))
+                self.assertTrue(
+                    index_of_tempdir1[0] and (index_of_tempdir1[0] == index_of_tempdir1[1])
+                )
+            finally:
+                if original is not None:
+                    os.environ["SPARK_LOCAL_DIRS"] = original
+                else:
+                    del os.environ["SPARK_LOCAL_DIRS"]
+
+    def test_simple_aggregator_with_medium_dataset(self):
+        # SPARK-39179: Test SimpleAggregator on a medium dataset
+        agg = SimpleAggregator(lambda x, y: x + y)
+        m = ExternalMerger(agg, 20)
+        m.mergeValues(self.data)
+        self.assertTrue(m.spills >= 1)
+        self.assertEqual(sum(v for k, v in m.items()), sum(range(self.N)))
+
+    def test_merger_not_implemented_error(self):
+        # SPARK-39179: Test Merger for error scenarios
+        agg = SimpleAggregator(lambda x, y: x + y)
+
+        class DummyMerger(Merger):
+            def __init__(self, agg):
+                Merger.__init__(self, agg)
+
+        dummy_merger = DummyMerger(agg)
+        with self.assertRaises(NotImplementedError):
+            dummy_merger.mergeValues(self.data)
+        with self.assertRaises(NotImplementedError):
+            dummy_merger.mergeCombiners(self.data)
+        with self.assertRaises(NotImplementedError):
+            dummy_merger.items()
+
     def test_huge_dataset(self):
         m = ExternalMerger(self.agg, 5, partitions=3)
         m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
@@ -117,6 +177,38 @@ class MergerTests(unittest.TestCase):
             m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
 
 
+class ExternalGroupByTests(unittest.TestCase):
+    def setUp(self):
+        self.N = 1 << 20
+        values = [i for i in range(self.N)]
+        keys = [i for i in range(2)]
+        import itertools
+
+        self.data = [value for value in itertools.product(keys, values)]
+        self.agg = Aggregator(
+            lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x
+        )
+
+    def test_medium_dataset(self):
+        # SPARK-39179: Test external group-by on a medium dataset
+        m = ExternalGroupBy(self.agg, 5, partitions=3)
+        m.mergeValues(self.data)
+        self.assertTrue(m.spills >= 1)
+        self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))
+
+    def test_dataset_with_keys_are_unsorted(self):
+        # SPARK-39179: Test external group-by when the number of keys is greater than SORT_KEY_LIMIT.
+        m = ExternalGroupBy(self.agg, 5, partitions=3)
+        original = m.SORT_KEY_LIMIT
+        try:
+            m.SORT_KEY_LIMIT = 1
+            m.mergeValues(self.data)
+            self.assertTrue(m.spills >= 1)
+            self.assertEqual(sum(sum(v) for k, v in m.items()), 2 * sum(range(self.N)))
+        finally:
+            m.SORT_KEY_LIMIT = original
+
+
 class SorterTests(unittest.TestCase):
     def test_in_memory_sort(self):
         lst = list(range(1024))

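For readers less familiar with the pyspark/shuffle.py internals that the new MergerTests
cases target: ExternalMerger picks its spill directories from the comma-separated
SPARK_LOCAL_DIRS environment variable (the new test also checks that the choice among
those directories is randomized), while SimpleAggregator is the Aggregator specialization
whose combiner has the same type as the values. A rough, self-contained sketch of both
pieces; the 20 MB limit, key count, and data size below are illustrative, not taken from
the patch:

    import os
    import tempfile

    from pyspark.shuffle import ExternalMerger, SimpleAggregator

    # Stand-alone sketch: spill into two throwaway directories and sum values
    # per key with a SimpleAggregator. (A real test would save and restore any
    # pre-existing SPARK_LOCAL_DIRS value, as the patch does.)
    with tempfile.TemporaryDirectory() as d1, tempfile.TemporaryDirectory() as d2:
        os.environ["SPARK_LOCAL_DIRS"] = d1 + "," + d2
        try:
            agg = SimpleAggregator(lambda x, y: x + y)  # combiner has the same type as the values
            m = ExternalMerger(agg, 20)  # ~20 MB soft limit, small enough to trigger spills
            m.mergeValues((i % 64, i) for i in range(1 << 16))
            # Spilled or not, the merged per-key sums add up to the sum of all inputs.
            assert sum(v for _, v in m.items()) == sum(range(1 << 16))
        finally:
            del os.environ["SPARK_LOCAL_DIRS"]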

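The new ExternalGroupByTests drive ExternalGroupBy, the group-by-key variant of
ExternalMerger used to collect all values for a key together; setting SORT_KEY_LIMIT to 1
in the second test takes the code path where the number of keys exceeds the limit, so
spilled data is not pre-sorted by key. A tiny sketch of the grouping behavior (the key
and value shapes are illustrative only, not from the patch):

    from pyspark.shuffle import Aggregator, ExternalGroupBy

    # Same aggregator shape as the new tests: collect values into a list per key.
    agg = Aggregator(
        lambda v: [v],                  # createCombiner: start a list with the first value
        lambda c, v: c.append(v) or c,  # mergeValue: append in place, return the list
        lambda a, b: a.extend(b) or a,  # mergeCombiners: concatenate in place
    )

    m = ExternalGroupBy(agg, 5, partitions=3)  # tiny (MB) limit so the external path is exercised
    m.mergeValues((i % 2, i) for i in range(10000))

    # items() yields (key, iterable-of-values); the totals match the input either way.
    assert sum(sum(vs) for _, vs in m.items()) == sum(range(10000))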