You are viewing a plain-text version of this content. The canonical link to it was not preserved in this version.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/12/21 01:37:52 UTC

[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053897964


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):
+        inputs = [("pytorch", 1, True, True),
+                  ("pytorch", 100, True, False),
+                  ("pytorch-lightning", 1, False, True),
+                  ("pytorch-lightning", 100, False, False)]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self):
+        inputs = [("tensorflow", 1, True, True, ValueError, "framework"),
+                  ("pytroch", 100, True, False, ValueError, "framework"),
+                  ("pytorchlightning", 1, False, True, ValueError, "framework"),
+                  ("pytorch-lightning", 0, False, False, ValueError, "positive")]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_get_correct_num_tasks_when_spark_conf_is_set(self):
+        inputs = [(1, 8, 8),
+                  (2, 8, 4),
+                  (3, 8, 3)]
+        # this is when the sparkconf isn't set
+        for _, num_processes, _ in inputs:
+            with self.subTest():
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), num_processes)
+        
+        # this is when the sparkconf is set
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", spark_conf_value)
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+    def test_encryption_passes(self):
+        input_combination = [("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+                             ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+                             ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = PyTorchDistributor("pytorch", 1, True, True)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self):
+        # this is the only combination that should fail
+        input_combination = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = PyTorchDistributor("pytorch", 1, True, True)
+                    distributor._check_encryption()
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

Review Comment:
   This file should probably be added to the test list in https://github.com/apache/spark/blob/master/dev/sparktestsupport/modules.py#L605 so that the test runner picks it up.



##########
python/pyspark/ml/torch/utils.py:
##########
@@ -0,0 +1,24 @@
+from pyspark.context import SparkContext

Review Comment:
   The license header is missing from this file.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org