Posted to reviews@spark.apache.org by "maddiedawson (via GitHub)" <gi...@apache.org> on 2023/07/12 20:43:05 UTC

[GitHub] [spark] maddiedawson commented on a diff in pull request #41946: [SPARK-44264] FunctionPickler Class

maddiedawson commented on code in PR #41946:
URL: https://github.com/apache/spark/pull/41946#discussion_r1261704748


##########
python/pyspark/ml/tests/test_util.py:
##########
@@ -15,63 +15,188 @@
 # limitations under the License.
 #
 
+from collections.abc import Iterable
+from contextlib import contextmanager
+import os
+import textwrap
+from typing import Iterator
+
 import unittest
 
+from pyspark import cloudpickle
 from pyspark.ml import Pipeline
 from pyspark.ml.classification import LogisticRegression, OneVsRest
 from pyspark.ml.feature import VectorAssembler
 from pyspark.ml.linalg import Vectors
-from pyspark.ml.util import MetaAlgorithmReadWrite
+from pyspark.ml.util import MetaAlgorithmReadWrite, FunctionPickler
 from pyspark.testing.mlutils import SparkSessionTestCase
 
-
 class MetaAlgorithmReadWriteTests(SparkSessionTestCase):
-    def test_getAllNestedStages(self):
-        def _check_uid_set_equal(stages, expected_stages):
-            uids = set(map(lambda x: x.uid, stages))
-            expected_uids = set(map(lambda x: x.uid, expected_stages))
-            self.assertEqual(uids, expected_uids)
-
-        df1 = self.spark.createDataFrame(
-            [
-                (Vectors.dense([1.0, 2.0]), 1.0),
-                (Vectors.dense([-1.0, -2.0]), 0.0),
-            ],
-            ["features", "label"],
-        )
-        df2 = self.spark.createDataFrame(
-            [
-                (1.0, 2.0, 1.0),
-                (1.0, 2.0, 0.0),
-            ],
-            ["a", "b", "label"],
-        )
-        vs = VectorAssembler(inputCols=["a", "b"], outputCol="features")
-        lr = LogisticRegression()
-        pipeline = Pipeline(stages=[vs, lr])
-        pipelineModel = pipeline.fit(df2)
-        ova = OneVsRest(classifier=lr)
-        ovaModel = ova.fit(df1)
-
-        ova_pipeline = Pipeline(stages=[vs, ova])
-        nested_pipeline = Pipeline(stages=[ova_pipeline])
-
-        _check_uid_set_equal(
-            MetaAlgorithmReadWrite.getAllNestedStages(pipeline), [pipeline, vs, lr]
-        )
-        _check_uid_set_equal(
-            MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
-            [pipelineModel] + pipelineModel.stages,
-        )
-        _check_uid_set_equal(MetaAlgorithmReadWrite.getAllNestedStages(ova), [ova, lr])
-        _check_uid_set_equal(
-            MetaAlgorithmReadWrite.getAllNestedStages(ovaModel), [ovaModel, lr] + ovaModel.models
-        )
-        _check_uid_set_equal(
-            MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
-            [nested_pipeline, ova_pipeline, vs, ova, lr],
+    def test_getAllNestedStages(self):
+        def _check_uid_set_equal(stages, expected_stages):
+            uids = set(map(lambda x: x.uid, stages))
+            expected_uids = set(map(lambda x: x.uid, expected_stages))
+            self.assertEqual(uids, expected_uids)
+
+        df1 = self.spark.createDataFrame(
+            [
+                (Vectors.dense([1.0, 2.0]), 1.0),
+                (Vectors.dense([-1.0, -2.0]), 0.0),
+            ],
+            ["features", "label"],
+        )
+        df2 = self.spark.createDataFrame(
+            [
+                (1.0, 2.0, 1.0),
+                (1.0, 2.0, 0.0),
+            ],
+            ["a", "b", "label"],
+        )
+        vs = VectorAssembler(inputCols=["a", "b"], outputCol="features")
+        lr = LogisticRegression()
+        pipeline = Pipeline(stages=[vs, lr])
+        pipelineModel = pipeline.fit(df2)
+        ova = OneVsRest(classifier=lr)
+        ovaModel = ova.fit(df1)
+
+        ova_pipeline = Pipeline(stages=[vs, ova])
+        nested_pipeline = Pipeline(stages=[ova_pipeline])
+
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipeline), [pipeline, vs, lr]
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(pipelineModel),
+            [pipelineModel] + pipelineModel.stages,
+        )
+        _check_uid_set_equal(MetaAlgorithmReadWrite.getAllNestedStages(ova), [ova, lr])
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(ovaModel), [ovaModel, lr] + ovaModel.models
+        )
+        _check_uid_set_equal(
+            MetaAlgorithmReadWrite.getAllNestedStages(nested_pipeline),
+            [nested_pipeline, ova_pipeline, vs, ova, lr],
+        )
+
+# Function that will be used to test pickling.
+def test_function(x: float, y: float) -> float:
+    return x**2 + y**2
+
+class TestFunctionPickler(unittest.TestCase):
+
+    def check_if_test_function_pickled(self, f, og_fn, output_value, *arguments, **key_word_args):
+        fn, args, kwargs = cloudpickle.load(f)
+        self.assertEqual(fn, og_fn)
+        self.assertEqual(args, arguments)
+        self.assertEqual(kwargs, key_word_args)
+        fn_output = fn(*args, **kwargs)
+        self.assertEqual(fn_output, output_value)
+
+    def test_pickle_fn_and_save(self):
+        x, y = 1, 3 # args of test_function
+        tmp_dir = "silly_goose"
+        os.mkdir(tmp_dir)
+        file_path_to_save = "silly_bear"
+        with self.subTest(msg="See if it pickles correctly if no file_path or save_dir are specified"):
+            pickled_fn_path = FunctionPickler.pickle_fn_and_save(test_function, "", "", x, y)
+            with open(pickled_fn_path, "rb") as f:
+                self.check_if_test_function_pickled(f, test_function, 10, x, y)
+            os.remove(pickled_fn_path)
+        with self.subTest(msg="See if pickles correctly and uses file path given as argument"):
+            pickled_fn_path = FunctionPickler.pickle_fn_and_save(test_function, file_path_to_save, "", x, y)
+            self.assertEqual(pickled_fn_path, file_path_to_save)
+            with open(pickled_fn_path, "rb") as f:
+                self.check_if_test_function_pickled(f, test_function, 10, x, y)
+            os.remove(pickled_fn_path)
+        with self.subTest(msg="See if pickles correctly and uses file path despite save_dir being specified"):
+            pickled_fn_path = FunctionPickler.pickle_fn_and_save(test_function, file_path_to_save, tmp_dir, x, y)
+            self.assertEqual(pickled_fn_path, file_path_to_save)
+            with open(pickled_fn_path, "rb") as f:
+                self.check_if_test_function_pickled(f, test_function, 10, x, y)
+            os.remove(pickled_fn_path)
+
+        os.rmdir(tmp_dir)
+
+    def test_getting_output_from_pickle_file(self):
+        a, b = 2, 0
+        pickle_fn_file = FunctionPickler.pickle_fn_and_save(test_function, "", "", a, b)
+        fn, args, kwargs = FunctionPickler.get_func_output(pickle_fn_file)
+        self.assertEqual(fn, test_function)
+        self.assertEqual(len(args), 2)
+        self.assertEqual(len(kwargs), 0)
+        self.assertEqual(args[0], a)
+        self.assertEqual(args[1], b)
+        self.assertEqual(fn(*args, **kwargs), 4)
+        os.remove(pickle_fn_file)
+
+    @contextmanager
+    def create_reference_file(self, body: str, prefix: str = "", suffix: str = "", fname: str = "reference.py") -> Iterator[None]:
+        try:
+            with open(fname, "w") as f:
+                if prefix != "":
+                    f.write(prefix)
+                f.write(body)
+                if suffix != "":
+                    f.write(suffix)
+            yield
+        finally:
+            os.remove(fname)

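The assertions in the tests above reduce to a plain cloudpickle round trip: pickle_fn_and_save is expected to write the tuple (fn, args, kwargs) to a file, and loading that file back should reproduce the callable together with its arguments. A minimal sketch of that round trip, done here with the vendored cloudpickle directly rather than through FunctionPickler (the square_sum helper and the temp-file handling are illustrative assumptions, not code from this PR):

    import os
    import tempfile

    from pyspark import cloudpickle


    def square_sum(x: float, y: float) -> float:
        return x**2 + y**2


    # Dump (fn, args, kwargs) the way the tests expect the pickle file to be laid out.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        cloudpickle.dump((square_sum, (1, 3), {}), f)
        path = f.name

    # Loading the file back yields the callable and its arguments unchanged.
    with open(path, "rb") as f:
        fn, args, kwargs = cloudpickle.load(f)

    assert fn(*args, **kwargs) == 10
    os.remove(path)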
Review Comment:
   Oh ok, thanks, I didn't notice the @contextmanager.
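
For context: because create_reference_file is decorated with @contextmanager, tests drive it with a with-statement; the file is written on entry and the finally clause removes it on exit, even if an assertion fails inside the block. A rough sketch of how such a helper would typically be used (the test body below is an illustrative assumption, not code from this PR):

    def test_reads_reference_file(self):
        # Hypothetical usage of the @contextmanager helper defined above.
        with self.create_reference_file("x = 1\n", fname="reference.py"):
            # The reference file exists only inside the with-block.
            self.assertTrue(os.path.isfile("reference.py"))
        # On exit, the helper's finally clause has removed it again.
        self.assertFalse(os.path.exists("reference.py"))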



