Posted to reviews@spark.apache.org by "zhengruifeng (via GitHub)" <gi...@apache.org> on 2023/03/30 10:15:15 UTC

[GitHub] [spark] zhengruifeng opened a new pull request, #40607: [WIP][ML] Make Torch Distributor support Spark Connect

zhengruifeng opened a new pull request, #40607:
URL: https://github.com/apache/spark/pull/40607

   ### What changes were proposed in this pull request?
   
   
   
   ### Why are the changes needed?
   
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   
   ### How was this patch tested?
   
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155452083


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import stat
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_passes(self):
+        super().test_encryption_passes()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_fails(self):
+        super().test_encryption_fails()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .config("spark.worker.resource.gpu.amount", "3")
+            .config("spark.task.cpus", "2")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.resource.gpu.amount", "1")
+            .remote("local-cluster[2,2,1024]")

Review Comment:
   ok, let me add one





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155432681


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,183 @@
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .config("spark.worker.resource.gpu.amount", "3")
+            .config("spark.task.cpus", "2")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.resource.gpu.amount", "1")
+            .remote("local-cluster[2,2,1024]")

Review Comment:
   No, I found that in Connect, `builder.config` won't accept a `SparkConf`.
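Since the Connect builder rejects a `SparkConf` object, the settings in `setUp` have to be applied one key at a time. A minimal sketch of doing that from a plain dict, assuming only the chainable `config(key, value)` call already used in the diff; `apply_confs` and `GPU_TEST_CONFS` are hypothetical names:

```python
from typing import Any, Dict

# Hypothetical: the GPU settings from the distributed test's setUp, kept in a dict.
GPU_TEST_CONFS: Dict[str, str] = {
    "spark.worker.resource.gpu.amount": "3",
    "spark.task.cpus": "2",
    "spark.task.resource.gpu.amount": "1",
    "spark.executor.resource.gpu.amount": "1",
}


def apply_confs(builder: Any, confs: Dict[str, str]) -> Any:
    """Call builder.config(key, value) once per entry (hypothetical helper).

    No SparkConf object is handed to the builder; each setting is passed
    as a plain string key/value pair instead.
    """
    for key, value in confs.items():
        builder = builder.config(key, value)
    return builder
```

With a Connect-capable PySpark this could be chained as `apply_confs(SparkSession.builder.appName(class_name), GPU_TEST_CONFS).remote("local-cluster[2,2,1024]").getOrCreate()`; the discovery-script path would still be added with its own `config` call, since it is generated per test.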





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156601296


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   I tested the benchmarks at https://www.guyrutenberg.com/2020/04/04/fast-bytes-concatenation-in-python/  
   and `join` was the fastest.
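For context, a minimal sketch of the two ways to reassemble the collected chunks: the loop from the diff (`output_bytes += row.chunk`) and `bytes.join`. Because `bytes` is immutable, each `+=` builds a new object and copies everything accumulated so far, which can make the loop quadratic in the number of chunks, while `join` sizes the result once. The function names are illustrative only:

```python
import timeit
from typing import Iterable, List


def concat_with_plus(chunks: Iterable[bytes]) -> bytes:
    # Mirrors the loop in the diff: every += allocates a new bytes object
    # and copies the whole accumulated buffer into it.
    output = b""
    for chunk in chunks:
        output += chunk
    return output


def concat_with_join(chunks: Iterable[bytes]) -> bytes:
    # bytes.join materializes the chunks, computes the total length,
    # and copies each chunk into the result exactly once.
    return b"".join(chunks)


if __name__ == "__main__":
    # 500 chunks of 4 KiB each (about 2 MiB), similar in shape to the chunked output.
    data: List[bytes] = [bytes(4096)] * 500
    assert concat_with_plus(data) == concat_with_join(data)
    for fn in (concat_with_plus, concat_with_join):
        print(fn.__name__, timeit.timeit(lambda: fn(data), number=3))
```

On the driver side this would amount to `concat_with_join(row.chunk for row in rows)`, assuming the `chunk` column from the diff's `schema="chunk binary"`.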





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156791974


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   > join cause memory issue on master ci
   
   What's the reason?








[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156934149


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   But your test only runs with a small dataset.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155435380


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +568,7 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                yield pd.DataFrame(data={"output": [cloudpickle.dumps(output)]})

Review Comment:
   > since the output dataframe only contains one row
   
   We can make it contain multiple rows; this is doable.
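A minimal sketch of the multi-row idea: serialize the output once and slice it into fixed-size pieces, one piece per row. It uses the standard-library `pickle` in place of `cloudpickle` so it runs without Spark; `to_chunks` and `from_chunks` are hypothetical names, and the 4096-byte default simply matches the `chunk_size = 4096` in the later revision of the diff:

```python
import pickle
from typing import Any, List


def to_chunks(obj: Any, chunk_size: int = 4096) -> List[bytes]:
    """Serialize obj and split the payload into pieces of at most chunk_size bytes."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    payload = pickle.dumps(obj)
    # Each slice would become one row, e.g. one value of a "chunk binary" column.
    return [payload[i : i + chunk_size] for i in range(0, len(payload), chunk_size)]


def from_chunks(chunks: List[bytes]) -> Any:
    """Reassemble the pieces (in their original order) and deserialize."""
    return pickle.loads(b"".join(chunks))
```

Inside the barrier task, partition 0 could then yield something like `pd.DataFrame({"chunk": to_chunks(output)})`, and the driver would call `from_chunks([row.chunk for row in rows])`. This relies on `collect()` returning the rows in the order they were emitted, which is plausible here because only one partition produces data, but is worth confirming.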





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153072179


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -581,11 +593,11 @@ def _run_distributed_training(
             f"Started distributed training with {self.num_processes} executor proceses"
         )
         try:
+            assert self.spark is not None
             result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="output binary", barrier=True)
+                .first()["output"]

Review Comment:
   got it, will fix





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155428117


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]

Review Comment:
   ```suggestion
           import pyspark.sql.connect.session
           spark = pyspark.sql.connect.session._active_spark_session  # type: ignore[assignment]
   ```





[GitHub] [spark] Yikun commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "Yikun (via GitHub)" <gi...@apache.org>.
Yikun commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1159456107


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {

Review Comment:
   https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners#supported-runners-and-hardware-resources
   
   Yes, the runners have 2 CPU cores and 7 GB of memory (2U7G); after accounting for system usage, about 2 cores and 6+ GB are available:
   https://github.com/apache/spark/blob/f541301b7680d96611796d92943d4ec72c71ec0d/.github/workflows/build_and_test.yml#L1028-L1029
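As a rough cross-check against those numbers, the test suites in this PR start `local-cluster[2,2,1024]`, which Spark reads as 2 workers with 2 cores and 1024 MiB each. A small sketch that parses that master string and totals the worker resources; `parse_local_cluster` is a hypothetical helper, and the totals cover only the worker allocations, not the driver, the Connect server, or Python worker processes:

```python
import re
from typing import NamedTuple


class LocalCluster(NamedTuple):
    workers: int
    cores_per_worker: int
    memory_mib_per_worker: int


_PATTERN = re.compile(r"local-cluster\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]")


def parse_local_cluster(master: str) -> LocalCluster:
    """Parse 'local-cluster[N, C, M]' into workers, cores per worker, MiB per worker."""
    match = _PATTERN.fullmatch(master.strip())
    if match is None:
        raise ValueError(f"not a local-cluster master URL: {master!r}")
    workers, cores, memory = (int(g) for g in match.groups())
    return LocalCluster(workers, cores, memory)


cluster = parse_local_cluster("local-cluster[2,2,1024]")
total_cores = cluster.workers * cluster.cores_per_worker  # 4 worker cores in total
total_worker_mib = cluster.workers * cluster.memory_mib_per_worker  # 2048 MiB
```

Under these assumptions the 4 worker cores oversubscribe a 2-core runner, and the 2 GiB of worker memory has to share the ~6 GB budget with everything else the test run starts.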





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156791129


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +561,21 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                output_bytes = cloudpickle.dumps(output)
+                output_size = len(output_bytes)
+
+                # In Spark Connect, DataFrame.collect stacks rows to size
+                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+                # here use 4KiB for each chunk, which mean each arrow batch
+                # may contain about 1000 chunks.
+                chunks = []
+                chunk_size = 4096

Review Comment:
   Makes sense.
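The numbers in the code comment above can be checked with a little arithmetic: with a 4 MiB batch limit (the default the diff comment cites for `spark.connect.grpc.arrow.maxBatchSize`) and 4 KiB chunks, one batch holds at most 4 MiB / 4 KiB = 1024 chunks, which is the "about 1000" mentioned. A sketch that ignores Arrow's per-row and metadata overhead (so real batches hold somewhat fewer chunks); the names are hypothetical:

```python
MAX_BATCH_BYTES = 4 * 1024 * 1024  # 4 MiB, per the comment in the diff
CHUNK_SIZE = 4096  # 4 KiB, matching chunk_size in the diff

CHUNKS_PER_BATCH = MAX_BATCH_BYTES // CHUNK_SIZE  # 1024


def num_chunks(output_size: int, chunk_size: int = CHUNK_SIZE) -> int:
    """Number of chunks needed for output_size bytes (integer ceiling division)."""
    return (output_size + chunk_size - 1) // chunk_size


def approx_num_batches(output_size: int) -> int:
    """Rough count of Arrow batches, ignoring per-row and metadata overhead."""
    return (num_chunks(output_size) + CHUNKS_PER_BATCH - 1) // CHUNKS_PER_BATCH
```

For example, a 10 MiB pickled output splits into 2,560 chunks and, under these assumptions, about 3 batches.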





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153067103


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -330,6 +340,7 @@ def __init__(
         num_processes: int = 1,
         local_mode: bool = True,
         use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,

Review Comment:
   Do we need this? Why not use the active session?
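One way to reconcile the two options, sketched under assumptions: keep the parameter optional but fall back to the active session when it is omitted, and fail early if neither exists. `resolve_session` is a hypothetical helper; the active-session lookup is passed in as a callable so the sketch does not depend on either the classic or the Connect session module:

```python
from typing import Callable, Optional, TypeVar

S = TypeVar("S")


def resolve_session(explicit: Optional[S], get_active: Callable[[], Optional[S]]) -> S:
    """Prefer an explicitly passed session; otherwise use the active one."""
    # get_active is only called when no session was passed in.
    session = explicit if explicit is not None else get_active()
    if session is None:
        raise RuntimeError("No active SparkSession; create one before constructing the distributor.")
    return session
```

In classic PySpark, `get_active` could be `SparkSession.getActiveSession`; for Spark Connect, the later revision of this diff reads the private `pyspark.sql.connect.session._active_spark_session` instead.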





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155432023


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import stat
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_passes(self):
+        super().test_encryption_passes()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_fails(self):
+        super().test_encryption_fails()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .config("spark.worker.resource.gpu.amount", "3")
+            .config("spark.task.cpus", "2")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.resource.gpu.amount", "1")
+            .remote("local-cluster[2,2,1024]")

Review Comment:
   Can we reuse the similar code on the spark.ml side?
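   Both `setUp` methods above repeat the same GPU discovery-script boilerplate, and the spark.ml tests contain a similar copy. Below is a sketch of a shared helper. It assumes the helper would live in a common test-utility module, and `create_gpu_discovery_script` is a hypothetical name, not an existing PySpark function. It writes the same `echo` command the tests use and applies the same permission bits.

```python
import os
import stat
import tempfile
from typing import Sequence


def create_gpu_discovery_script(addresses: Sequence[str]) -> str:
    """Hypothetical helper: write an executable GPU discovery script and return its path.

    The script echoes {"name": "gpu", "addresses": [...]} exactly like the
    byte string in the tests above. The caller must os.unlink() the path.
    """
    quoted = ",".join(f'\\"{a}\\"' for a in addresses)
    body = f'echo {{\\"name\\": \\"gpu\\", \\"addresses\\": [{quoted}]}}'
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
        f.write(body)
        path = f.name
    # Same permission bits as the original setUp code.
    os.chmod(path, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH)
    return path
```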





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155449960


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote connect session, the SparkConf
+    code path always fails, so this falls back to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   It seems we can use `RuntimeConf` directly instead; let me give it a try.
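   Below is a sketch of reading confs through any object that exposes `get(key, default)`. A plain dict has this shape, and so does PySpark's `spark.conf` (RuntimeConfig) in both classic and Connect sessions. The sketch assumes, as this comment suggests, that the keys the distributor needs are visible through the runtime conf. Because the new `_get_conf` returns a string, callers that need a boolean must parse it. `parse_conf_boolean` reproduces the case-insensitive check from the removed `get_conf_boolean`. Both function names are illustrative.

```python
from typing import Any


def get_conf(conf: Any, key: str, default_value: str) -> str:
    """Read `key` from any object exposing `get(key, default)`.

    Works with a dict, and by assumption with `spark.conf` for both classic
    and Spark Connect sessions.
    """
    return str(conf.get(key, default_value))


def parse_conf_boolean(key: str, value: str) -> bool:
    """Case-insensitive "true"/"false" parsing, as the removed get_conf_boolean did."""
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise ValueError(
        f"The conf value for '{key}' was expected to be a boolean but found: {value!r}"
    )
```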





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156563413


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   Did you test the performance of concatenating them?
   @HyukjinKwon Do you have a better approach for concatenating them more efficiently?
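   A commonly used alternative to `+=` in a loop is `bytes.join`. Below is a sketch in which a `namedtuple` stands in for PySpark's `Row`. The complexity note in the docstring describes general CPython behavior. It is not a measurement taken for this PR.

```python
from collections import namedtuple
from typing import Iterable

# Stand-in for the Row objects returned by DataFrame.collect() in the diff.
Row = namedtuple("Row", ["chunk"])


def concat_chunks(rows: Iterable[Row]) -> bytes:
    """Concatenate the `chunk` column of the collected rows in a single pass.

    b"".join sizes the result once and copies each chunk once. Repeated
    `output_bytes += row.chunk` builds a new bytes object for every row, so it
    can copy O(n^2) bytes in total for n similarly sized chunks.
    """
    return b"".join(row.chunk for row in rows)
```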





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156563103


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +561,21 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                output_bytes = cloudpickle.dumps(output)
+                output_size = len(output_bytes)
+
+                # In Spark Connect, DataFrame.collect stacks rows to size
+                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+                # here use 4KiB for each chunk, which means each arrow batch
+                # may contain about 1000 chunks.
+                chunks = []
+                chunk_size = 4096

Review Comment:
   I think we should use a larger chunk size; I recommend 256MB.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156574015


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +561,21 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                output_bytes = cloudpickle.dumps(output)
+                output_size = len(output_bytes)
+
+                # In Spark Connect, DataFrame.collect stacks rows to size
+                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+                # here use 4KiB for each chunk, which means each arrow batch
+                # may contain about 1000 chunks.
+                chunks = []
+                chunk_size = 4096

Review Comment:
   The Arrow batch size is actually controlled by `spark.connect.grpc.arrow.maxBatchSize`, and if a batch exceeds the maximum gRPC message size, the request fails.
   So I don't think we need a larger chunk size here.
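   The sizing argument in this thread comes down to simple arithmetic. Below is a sketch that assumes the 4MiB default for `spark.connect.grpc.arrow.maxBatchSize` quoted in the code comment and ignores Arrow framing overhead. It also assumes that at least one row is always emitted per batch, so an oversized chunk produces an oversized batch. `chunks_per_batch` is a hypothetical helper.

```python
KIB = 1024
MIB = 1024 * KIB


def chunks_per_batch(max_batch_bytes: int, chunk_bytes: int) -> int:
    """Roughly how many fixed-size binary rows fit into one Arrow batch.

    Ignores per-row and per-batch Arrow overhead, so the real count is a bit
    lower. Assumes (not verified here) that every batch carries at least one row.
    """
    if chunk_bytes <= 0 or max_batch_bytes <= 0:
        raise ValueError("sizes must be positive")
    return max(1, max_batch_bytes // chunk_bytes)


print(chunks_per_batch(4 * MIB, 4 * KIB))    # 1024, the "about 1000 chunks" in the code comment
print(chunks_per_batch(4 * MIB, 256 * MIB))  # 1, a single 256MB row already exceeds the 4MiB target
```

   With 256MB chunks, one row would be 64 times the batch target. Such a batch risks hitting the gRPC message limit described in this comment.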





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155435075


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import stat
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_passes(self):
+        super().test_encryption_passes()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_fails(self):
+        super().test_encryption_fails()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .config("spark.worker.resource.gpu.amount", "3")
+            .config("spark.task.cpus", "2")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.resource.gpu.amount", "1")
+            .remote("local-cluster[2,2,1024]")

Review Comment:
   We can still write a helper function for this part; only a few of its arguments differ.
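   Below is a sketch of such a helper, parameterized on the one real difference: whether the GPU resource confs are set at the `driver` scope (local tests) or the `worker` scope (distributed tests). `gpu_test_confs` is a hypothetical name. The conf keys and values are the ones used in the two `setUp` methods above. `spark.test.home` is common to both and is left to the caller.

```python
from typing import Dict


def gpu_test_confs(script_path: str, scope: str, amount: int = 3) -> Dict[str, str]:
    """Hypothetical helper returning the GPU-related confs set in the setUp methods.

    scope="driver" matches the local-mode tests; scope="worker" matches the
    distributed tests, which also pin task/executor resources.
    """
    if scope not in ("driver", "worker"):
        raise ValueError(f"unsupported scope: {scope!r}")
    confs = {
        f"spark.{scope}.resource.gpu.amount": str(amount),
        f"spark.{scope}.resource.gpu.discoveryScript": script_path,
    }
    if scope == "worker":
        confs.update(
            {
                "spark.task.cpus": "2",
                "spark.task.resource.gpu.amount": "1",
                "spark.executor.resource.gpu.amount": "1",
            }
        )
    return confs


# Usage sketch (requires pyspark):
#   builder = SparkSession.builder.appName(class_name).config("spark.test.home", SPARK_HOME)
#   for key, value in gpu_test_confs(script_path, "worker").items():
#       builder = builder.config(key, value)
#   spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
```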





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155446538


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +568,7 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                yield pd.DataFrame(data={"output": [cloudpickle.dumps(output)]})

Review Comment:
   Sounds good: break the row into multiple rows in partition 0, and then concatenate them in the ~~driver~~ Python client.
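   Below is a sketch of that split-and-rejoin round trip. Stdlib `pickle` stands in for `cloudpickle`, and plain `bytes` stand in for the rows. In the real task function, each chunk would be emitted as a row of the `chunk binary` schema used in the later diff. The function names are hypothetical.

```python
import pickle
from typing import Any, Iterator, List

CHUNK_SIZE = 4096  # the 4KiB chunk size used in the later diff


def partition_output_chunks(
    partition_id: int, output: Any, chunk_size: int = CHUNK_SIZE
) -> Iterator[bytes]:
    """Yield the pickled output as fixed-size chunks, but only from partition 0.

    Other partitions yield nothing, mirroring `if context.partitionId() == 0:`.
    """
    if partition_id != 0:
        return
    payload = pickle.dumps(output)
    for start in range(0, len(payload), chunk_size):
        yield payload[start : start + chunk_size]


def reassemble(chunks: List[bytes]) -> Any:
    """Client side: concatenate the chunks in order and unpickle the result."""
    return pickle.loads(b"".join(chunks))
```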





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155439978


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote connect session, the SparkConf
+    code path always fails, so this falls back to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   We can, but there would be many session launches/stops in the unit tests.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155481805


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -117,21 +117,12 @@ def train_fn(learning_rate: float) -> Any:
                 optimizer.step()
             print(f"epoch {epoch} finished.")
 
-        return "success"
+        return "success" * (1 << 20)

Review Comment:
   Increase the model size to exercise the `model split` code path.
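   The following rough arithmetic shows why the larger return value reaches the split path. `"success"` is 7 characters, so repeating it `1 << 20` times gives a payload of about 7 MiB. That is more than one 4 MiB Arrow batch and many 4KiB chunks. Stdlib `pickle` stands in for `cloudpickle` here. Its output for a plain `str` should be close in size but is not guaranteed to be byte-identical.

```python
import pickle

payload = "success" * (1 << 20)   # the test's new return value
pickled = pickle.dumps(payload)   # stand-in for cloudpickle.dumps
chunk_size = 4096                 # 4KiB chunks, as in the distributor diff
num_chunks = -(-len(pickled) // chunk_size)  # integer ceiling division

print(len(payload))  # 7340032
print(num_chunks)    # 1793
```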





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153069493


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,511 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import os
+import shutil
+from six import StringIO
+import stat
+import subprocess
+import sys
+import time
+import tempfile
+import threading
+from typing import Callable, Dict, Any
+import unittest
+from unittest.mock import patch
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark import SparkConf
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+@contextlib.contextmanager
+def patch_stdout() -> StringIO:
+    """patch stdout and give an output"""
+    sys_stdout = sys.stdout
+    io_out = StringIO()
+    sys.stdout = io_out
+    try:
+        yield io_out
+    finally:
+        sys.stdout = sys_stdout
+
+
+def create_training_function(mnist_dir_path: str) -> Callable:
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torchvision import transforms, datasets
+
+    batch_size = 100
+    num_epochs = 1
+    momentum = 0.5
+
+    train_dataset = datasets.MNIST(
+        mnist_dir_path,
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
+
+    class Net(nn.Module):
+        def __init__(self) -> None:
+            super(Net, self).__init__()
+            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+            self.conv2_drop = nn.Dropout2d()
+            self.fc1 = nn.Linear(320, 50)
+            self.fc2 = nn.Linear(50, 10)
+
+        def forward(self, x: Any) -> Any:
+            x = F.relu(F.max_pool2d(self.conv1(x), 2))
+            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+            x = x.view(-1, 320)
+            x = F.relu(self.fc1(x))
+            x = F.dropout(x, training=self.training)
+            x = self.fc2(x)
+            return F.log_softmax(x, dim=1)
+
+    def train_fn(learning_rate: float) -> Any:
+        import torch
+        import torch.optim as optim
+        import torch.distributed as dist
+        from torch.nn.parallel import DistributedDataParallel as DDP
+        from torch.utils.data.distributed import DistributedSampler
+
+        dist.init_process_group("gloo")
+
+        train_sampler = DistributedSampler(dataset=train_dataset)
+        data_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=batch_size, sampler=train_sampler
+        )
+
+        model = Net()
+        ddp_model = DDP(model)
+        optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)
+        for epoch in range(1, num_epochs + 1):
+            ddp_model.train()
+            for _, (data, target) in enumerate(data_loader):
+                optimizer.zero_grad()
+                output = ddp_model(data)
+                loss = F.nll_loss(output, target)
+                loss.backward()
+                optimizer.step()
+            print(f"epoch {epoch} finished.")
+
+        return "success"
+
+    return train_fn
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key in input_map.keys():
+            del os.environ[key]
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            (1, True, False),
+            (100, True, False),
+            (1, False, False),
+            (100, False, False),
+        ]
+        for num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                expected_params = {
+                    "num_processes": num_processes,
+                    "local_mode": local_mode,
+                    "use_gpu": use_gpu,
+                    "num_tasks": num_processes,
+                }
+                dist = TorchDistributor(num_processes, local_mode, use_gpu, spark=self.spark)
+                self.assertEqual(expected_params, dist.input_params)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            (0, False, False, ValueError, "positive"),
+        ]
+        for num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(num_processes, local_mode, use_gpu)
+
+    # TODO: Support reading SparkConf and initializing a remote session with SSL options
+    # def test_encryption_passes(self) -> None:
+    #     inputs = [
+    #         ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+    #         ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+    #         ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+    #     ]
+    #     for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+    #         with self.subTest():
+    #             self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+    #             self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+    #             distributor = TorchDistributor(1, True, False)
+    #             distributor._check_encryption()
+
+    # def test_encryption_fails(self) -> None:
+    #     # this is the only combination that should fail
+    #     inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+    #     for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+    #         with self.subTest():
+    #             with self.assertRaisesRegex(Exception, "encryption"):
+    #                 self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+    #                 self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+    #                 distributor = TorchDistributor(1, True, False)
+    #                 distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+    def test_execute_command(self) -> None:
+        """Test that run command runs the process and logs are written correctly"""
+
+        with patch_stdout() as output:
+            stdout_command = ["echo", "hello_stdout"]
+            TorchDistributor._execute_command(stdout_command)
+            self.assertIn(
+                "hello_stdout", output.getvalue().strip(), "hello_stdout should print to stdout"
+            )
+
+        with patch_stdout() as output:
+            stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
+            TorchDistributor._execute_command(stderr_command)
+            self.assertIn(
+                "hello_stderr", output.getvalue().strip(), "hello_stderr should print to stdout"
+            )
+
+        # include command in the exception message
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
+            error_command = ["bash", "-c", "exit 1"]
+            TorchDistributor._execute_command(error_command)
+
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
+            error_command = ["bash", "-c", "'abc''def'"]
+            TorchDistributor._execute_command(error_command)
+
+    def test_create_torchrun_command(self) -> None:
+        train_path = "train.py"
+        args_string = ["1", "3"]
+        local_mode_input_params = {"num_processes": 4, "local_mode": True}
+
+        expected_local_mode_output = [
+            sys.executable,
+            "-m",
+            "pyspark.ml.torch.torch_run_process_wrapper",
+            "--standalone",
+            "--nnodes=1",
+            "--nproc_per_node=4",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                local_mode_input_params, train_path, *args_string
+            ),
+            expected_local_mode_output,
+        )
+
+        distributed_mode_input_params = {"num_processes": 4, "local_mode": False}
+        input_env_vars = {"MASTER_ADDR": "localhost", "MASTER_PORT": "9350", "RANK": "3"}
+
+        args_number = [1, 3]  # testing conversion to strings
+        self.setup_env_vars(input_env_vars)
+        expected_distributed_mode_output = [
+            sys.executable,
+            "-m",
+            "pyspark.ml.torch.torch_run_process_wrapper",
+            "--nnodes=4",
+            "--node_rank=3",
+            "--rdzv_endpoint=localhost:9350",
+            "--rdzv_id=0",
+            "--nproc_per_node=1",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                distributed_mode_input_params, train_path, *args_number
+            ),
+            expected_distributed_mode_output,
+        )
+        self.delete_env_vars(input_env_vars)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key in input_map.keys():
+            del os.environ[key]
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor(num_processes, True, True, spark=self.spark)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertLogs("TorchDistributor", level="WARNING") as log:
+                    distributor = TorchDistributor(num_processes, True, True, spark=self.spark)
+                    self.assertEqual(len(log.records), 1)
+                    self.assertEqual(distributor.num_processes, 3)
+
+    # TODO: support SparkConf
+    # def test_get_gpus_owned_local(self) -> None:
+    #     addresses = ["0", "1", "2"]
+    #     self.assertEqual(get_gpus_owned(self.sc), addresses)
+    #
+    #     env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"}
+    #     self.setup_env_vars(env_vars)
+    #     self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+    #     self.delete_env_vars(env_vars)
+
+    def test_local_training_succeeds(self) -> None:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        inputs = [
+            ("0,1,2", 1, True, "1"),
+            ("0,1,2", 3, True, "1,2,0"),
+            ("0,1,2", 2, False, "0,1,2"),
+            (None, 3, False, "NONE"),
+        ]
+
+        for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs):
+            with self.subTest(f"subtest: {i + 1}"):
+                # setup
+                if cuda_env_var:
+                    self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
+                dist = TorchDistributor(num_processes, True, use_gpu, spark=self.spark)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
+                    CUDA_VISIBLE_DEVICES, "NONE"
+                )
+                self.assertEqual(
+                    expected,
+                    dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"),
+                )
+                # cleanup
+                if cuda_env_var:
+                    self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
+    def test_local_file_with_pytorch(self) -> None:
+        spark = SparkSession.builder.remote("local[4]").getOrCreate()
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=True, use_gpu=False, spark=spark).run(
+            test_file_path, learning_rate_str
+        )
+
+    def test_end_to_end_run_locally(self) -> None:
+        spark = SparkSession.builder.remote("local[4]").getOrCreate()
+        train_fn = create_training_function(self.mnist_dir_path)
+        output = TorchDistributor(num_processes=2, local_mode=True, use_gpu=False, spark=spark).run(
+            train_fn, 0.001
+        )
+        self.assertEqual(output, "success")
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set(
+            "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+        )
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config(conf=conf)
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    def test_dist_training_succeeds(self) -> None:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        inputs = [
+            ("0,1,2", 2, True, "0"),
+        ]
+
+        for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
+            with self.subTest(f"subtest: {i + 1}"):
+                dist = TorchDistributor(num_processes, False, use_gpu, spark=self.spark)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
+                    CUDA_VISIBLE_DEVICES, "NONE"
+                )
+                self.assertEqual(
+                    expected,
+                    dist._run_distributed_training(dist._run_training_on_pytorch_file, "..."),
+                )
+
+    # Cannot set SparkConf
+    # def test_get_num_tasks_distributed(self) -> None:
+    #     inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+    #
+    #     for spark_conf_value, num_processes, expected_output in inputs:
+    #         with self.subTest():
+    #             self.spark.sparkContext._conf.set(
+    #                 "spark.task.resource.gpu.amount", str(spark_conf_value)
+    #             )
+    #             distributor = TorchDistributor(num_processes, False, True, self.spark)
+    #             self.assertEqual(distributor._get_num_tasks(), expected_output)
+    #
+    #     self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+    def test_distributed_file_with_pytorch(self) -> None:
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=False, use_gpu=False, spark=self.spark).run(
+            test_file_path, learning_rate_str
+        )
+
+    def test_end_to_end_run_distributedly(self) -> None:
+        train_fn = create_training_function(self.mnist_dir_path)
+        output = TorchDistributor(
+            num_processes=2, local_mode=False, use_gpu=False, spark=self.spark
+        ).run(train_fn, 0.001)
+        self.assertEqual(output, "success")
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchWrapperUnitTests(unittest.TestCase):
+    def test_clean_and_terminate(self) -> None:
+        def kill_task(task: "subprocess.Popen") -> None:
+            time.sleep(1)
+            clean_and_terminate(task)
+
+        command = [sys.executable, "-c", '"import time; time.sleep(20)"']
+        task = subprocess.Popen(command)
+        t = threading.Thread(target=kill_task, args=(task,))
+        t.start()
+        time.sleep(2)
+        self.assertEqual(task.poll(), 0)  # implies task ended
+
+    @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate")
+    def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
+        command = [sys.executable, "-c", '"import time; time.sleep(2)"']
+        task = subprocess.Popen(command)
+        t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+        t.start()
+        time.sleep(2)
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)

Review Comment:
   Q:
   
   Why do we need to copy the test code? Can we reuse the original test code, like the DataFrame Connect test suite does?
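   The reviewer seems to be suggesting the pattern the Connect parity suites use. The test bodies live in a mixin that only touches `self.spark`, and each concrete class decides how to build the session. The later traceback in this thread runs `pyspark/ml/torch/tests/test_distributor.py` under `TorchDistributorDistributedUnitTestsOnConnect`, which suggests the PR moved in this direction. Below is a minimal, Spark-free sketch of the pattern. `FakeSession`, `DistributorTestsMixin` and the two test classes are hypothetical names, and the real session builders appear only in comments:

```python
import unittest


class FakeSession:
    """Hypothetical stand-in for a SparkSession so the sketch runs without Spark."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


class DistributorTestsMixin:
    """Shared test bodies. They use only `self.spark` (never `sparkContext`),
    so the same methods can run against a classic or a Connect session."""

    spark: FakeSession

    def test_session_is_active(self) -> None:
        self.assertFalse(self.spark.stopped)
        self.assertIn(self.spark.kind, ("classic", "connect"))


class DistributorTests(DistributorTestsMixin, unittest.TestCase):
    def setUp(self) -> None:
        # Real suite (assumption): SparkSession.builder.master("local[4]").getOrCreate()
        self.spark = FakeSession("classic")

    def tearDown(self) -> None:
        self.spark.stop()


class DistributorTestsOnConnect(DistributorTestsMixin, unittest.TestCase):
    def setUp(self) -> None:
        # Real suite (assumption): SparkSession.builder.remote("local[4]").getOrCreate()
        self.spark = FakeSession("connect")

    def tearDown(self) -> None:
        self.spark.stop()
```

   Each concrete class runs the same `test_session_is_active` body. The mixin is not a `TestCase`, so it is never collected on its own.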




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "HyukjinKwon (via GitHub)" <gi...@apache.org>.
HyukjinKwon commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155403791


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -176,20 +184,19 @@ def _get_num_tasks(self) -> int:
         RuntimeError
             Raised when the SparkConf was misconfigured.
         """
-
         if self.use_gpu:
             if not self.local_mode:
                 key = "spark.task.resource.gpu.amount"
-                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
                 # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
                 return math.ceil(self.num_processes / task_gpu_amount)
             else:
                 key = "spark.driver.resource.gpu.amount"
-                if "gpu" not in self.sc.resources:
+                if "gpu" not in self.spark.sparkContext.resources:

Review Comment:
   How does this work in Spark Connect?




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156990769


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   We can skip this for now. Use `+` instead and mark a TODO here.
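   The `+` lines collect rows with a single `chunk binary` column and concatenate them. The sketch below shows that round trip outside Spark. It assumes the training output is pickled, split into fixed-size chunks, and returned in the order the chunks were written. `Row`, `CHUNK_SIZE`, `to_chunks` and `from_rows` are hypothetical names. The sketch also shows `b"".join(...)` as one possible follow-up for the TODO. Repeated `+=` on immutable `bytes` builds a new object on every iteration, while `join` copies each chunk once.

```python
import pickle
from collections import namedtuple
from typing import Iterator, List

# Hypothetical stand-in for a collected pyspark Row with one binary column "chunk".
Row = namedtuple("Row", ["chunk"])

# Hypothetical chunk size; the real value is not shown in this hunk.
CHUNK_SIZE = 1024


def to_chunks(obj: object, chunk_size: int = CHUNK_SIZE) -> Iterator[bytes]:
    """Pickle the training output and split it into fixed-size pieces,
    one piece per value of the `chunk` column."""
    data = pickle.dumps(obj)
    for start in range(0, len(data), chunk_size):
        yield data[start : start + chunk_size]


def from_rows(rows: List[Row]) -> object:
    """Reassemble the collected rows in order and unpickle the result.
    b"".join builds the output once instead of re-copying a growing
    buffer on every `output_bytes += row.chunk` iteration."""
    return pickle.loads(b"".join(row.chunk for row in rows))
```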




[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1497315263

   I am hitting a strange failure in `TorchDistributorDistributedUnitTestsOnConnect.test_parity_torch_distributor`. It appeared after I rebased this PR yesterday, but I can't find any suspicious commits merged recently.
   
   ```
   ======================================================================
   ERROR [18.362s]: test_end_to_end_run_distributedly (pyspark.ml.tests.connect.test_parity_torch_distributor.TorchDistributorDistributedUnitTestsOnConnect)
   ----------------------------------------------------------------------
   Traceback (most recent call last):
     File "/__w/spark/spark/python/pyspark/ml/torch/tests/test_distributor.py", line 457, in test_end_to_end_run_distributedly
       output = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(
     File "/__w/spark/spark/python/pyspark/ml/torch/distributor.py", line 749, in run
       output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
     File "/__w/spark/spark/python/pyspark/ml/torch/distributor.py", line 607, in _run_distributed_training
       self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
     File "/__w/spark/spark/python/pyspark/sql/connect/dataframe.py", line 1354, in collect
       table, schema = self._session.client.to_table(query)
     File "/__w/spark/spark/python/pyspark/sql/connect/client.py", line 668, in to_table
       table, schema, _, _, _ = self._execute_and_fetch(req)
     File "/__w/spark/spark/python/pyspark/sql/connect/client.py", line 982, in _execute_and_fetch
       for response in self._execute_and_fetch_as_iterator(req):
     File "/__w/spark/spark/python/pyspark/sql/connect/client.py", line 963, in _execute_and_fetch_as_iterator
       self._handle_error(error)
     File "/__w/spark/spark/python/pyspark/sql/connect/client.py", line 1055, in _handle_error
       self._handle_rpc_error(error)
     File "/__w/spark/spark/python/pyspark/sql/connect/client.py", line 1095, in _handle_rpc_error
       raise SparkConnectGrpcException(str(rpc_error)) from None
   pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
   	status = StatusCode.UNKNOWN
   	details = "Java heap space"
   	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:35071 {created_time:"2023-04-05T10:13:52.254507275+00:00", grpc_status:2, grpc_message:"Java heap space"}"
   >
   ```
   
   
   In my local environment, I can only reproduce this by decreasing the driver memory (e.g. `spark.driver.memory=512M`), and the issue is resolved simply by increasing the driver memory to 1024M.
   I tested several combinations locally, such as:
   `spark.driver.memory=1024M, spark.executor.memory=512M`
   `spark.driver.memory=1024M, spark.executor.memory=1024M`
   and they all work as expected.
   
   But in GitHub Actions (the resource limit appears to be set here: https://github.com/apache/spark/blob/0b45a5278026c2ea9ce2b127333514f7a7a933f4/.github/workflows/build_and_test.yml#L1028), no matter how much driver memory I set (3G, 4G), this test keeps failing with this error message.
   
   Do you have any thoughts on this? @WeichenXu123 @HyukjinKwon 
   



[GitHub] [spark] WeichenXu123 commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1498386497

   @HyukjinKwon WDYT? Can we increase the memory capacity of the CI bot?



[GitHub] [spark] zhengruifeng closed pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng closed pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect
URL: https://github.com/apache/spark/pull/40607



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1159434752


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {

Review Comment:
   ```
   Started distributed training with 2 executor processes
   java.lang.OutOfMemoryError: Java heap space
   	at java.util.Arrays.copyOfRange(Arrays.java:3664)
   	at java.lang.String.<init>(String.java:207)
   	at java.lang.StringBuilder.toString(StringBuilder.java:407)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:112)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:119)
   	at org.sparkproject.connect.protobuf.TextFormat.escapeBytes(TextFormat.java:2364)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:593)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   Extracting /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw
   ```
   
   `v.toString` keeps throwing OOM in `TorchDistributorDistributedUnitTestsOnConnect`.
   This OOM depends on the OS or Java version: it is thrown on Linux + Java 8, but does not occur in my local environment (macOS + Java 11).
   
   The free GitHub Actions (GA) runners are limited to 2 cores and 6 GB of memory (confirmed with @Yikun), so I believe we cannot allocate enough driver memory for distributed PyTorch training without this fix.
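The repeated `TextFormat$Printer` frames suggest that the protobuf printer recurses through the nested fields of a large message while building its string form. A plausible culprit, not confirmed in this thread, is a large binary payload such as a pickled training closure. One common mitigation is to log an abbreviated copy of the message instead of the full `toString` output. The sketch below is plain Python over a dict model of a message. It is not the protobuf or Spark Connect API, and `abbreviate`, `MAX_FIELD_LEN`, and the `plan` layout are hypothetical.

```python
from typing import Any

# Hypothetical threshold: keep at most this many leading bytes/characters of any field.
MAX_FIELD_LEN = 64


def abbreviate(value: Any, max_len: int = MAX_FIELD_LEN) -> Any:
    """Return a copy of a nested message (dicts and lists) with long bytes/str fields truncated."""
    if isinstance(value, (bytes, bytearray)):
        if len(value) > max_len:
            suffix = f"...[{len(value)} bytes total]".encode()
            return bytes(value[:max_len]) + suffix
        return bytes(value)
    if isinstance(value, str):
        if len(value) > max_len:
            return value[:max_len] + f"...[{len(value)} chars total]"
        return value
    if isinstance(value, dict):
        return {key: abbreviate(item, max_len) for key, item in value.items()}
    if isinstance(value, list):
        return [abbreviate(item, max_len) for item in value]
    return value  # numbers, booleans, None, etc. are small, so keep them as-is


# A plan-like message whose payload stands in for a large pickled function.
plan = {"command": {"name": "train", "payload": b"\x00" * 1_000_000}}
print(abbreviate(plan))  # a short line instead of megabytes of escaped bytes
```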



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156795362


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   This is for debugging: I suspect `join` would consume more memory.
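The loop in the diff reassembles the pickled training result on the driver from the `chunk` rows. Below is a Spark-free sketch of that chunk-and-reassemble round trip. It assumes the task side pickles its return value and splits it into fixed-size pieces, and `split_into_chunks`, `CHUNK_SIZE`, and the two concatenation helpers are hypothetical names. It contrasts the `+=` loop with appending into a single growable `bytearray`, which avoids re-copying the accumulated prefix. Which variant has the lower peak memory in the real job is exactly what is being debugged here, so treat this as an illustration rather than a measurement.

```python
import pickle
from typing import Any, Iterable, Iterator, List

CHUNK_SIZE = 4  # hypothetical; a real job would use a much larger size


def split_into_chunks(payload: bytes, chunk_size: int = CHUNK_SIZE) -> Iterator[bytes]:
    """Yield consecutive slices of `payload`, each at most `chunk_size` bytes."""
    for start in range(0, len(payload), chunk_size):
        yield payload[start : start + chunk_size]


def concat_with_plus(chunks: Iterable[bytes]) -> bytes:
    # Mirrors the loop in the diff: each `+=` may create a new bytes object and
    # copy everything accumulated so far, so total copying can grow quadratically.
    output = b""
    for chunk in chunks:
        output += chunk
    return output


def concat_with_bytearray(chunks: Iterable[bytes]) -> bytearray:
    # Appends into one growable buffer; pickle.loads accepts any bytes-like
    # object, so the result does not need a final bytes(...) copy.
    buffer = bytearray()
    for chunk in chunks:
        buffer.extend(chunk)
    return buffer


result: Any = {"status": "success", "loss": [0.5, 0.25]}
chunks: List[bytes] = list(split_into_chunks(pickle.dumps(result)))
assert pickle.loads(concat_with_plus(chunks)) == result
assert pickle.loads(concat_with_bytearray(chunks)) == result
# b"".join(chunks) is a third option: it sizes the output first and copies each chunk once,
# but it needs all chunks to be alive at the same time as the output.
```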



[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1496734980

   In my local environment, the failing test passes even with a bigger model size,
   but let me try reducing the model size for GA to see what happens.


[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156934790


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   But your test only runs with a small dataset.



[GitHub] [spark] jaceklaskowski commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "jaceklaskowski (via GitHub)" <gi...@apache.org>.
jaceklaskowski commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155657595


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -263,6 +265,10 @@ class TorchDistributor(Distributor):
 
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect. Note local model with GPU is not supported yet, will be fixed

Review Comment:
   nit: Note that...
   
   Also, "local model" or "local mode"?



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155434477


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote connect session, the SparkConf
+    code path always fail, and fallback to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   Why not update this part of the code ("change the cached conf sc._conf")? We shouldn't modify the internal `sc._conf` attribute.
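The helper in the diff reads the `SparkConf` for a classic session and falls back to the runtime conf for a Connect session, and the boolean parsing from the removed `get_conf_boolean` no longer lives in it. Here is a minimal sketch of that split. It assumes the two conf sources can be passed in explicitly, which would also let tests supply their own values instead of mutating the internal `sc._conf`. `DictConf`, `get_conf`, and `conf_to_bool` are hypothetical names, not PySpark APIs.

```python
from typing import Dict, Optional


class DictConf:
    """Hypothetical stand-in for a conf object exposing get(key, default)."""

    def __init__(self, values: Dict[str, str]) -> None:
        self._values = dict(values)

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        return self._values.get(key, default)


def get_conf(
    remote: bool, spark_conf: DictConf, runtime_conf: DictConf, key: str, default_value: str
) -> str:
    # Classic sessions read the SparkConf; a remote (Connect) client has no
    # SparkContext, so it reads the session's runtime conf instead.
    conf = runtime_conf if remote else spark_conf
    value = conf.get(key, default_value)
    assert value is not None  # default_value guarantees a string here
    return value


def conf_to_bool(key: str, value: str) -> bool:
    # Same rule as the removed get_conf_boolean: case-insensitive "true"/"false" only.
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise ValueError(f"The conf value for '{key}' was expected to be a boolean, found: {value}")


spark_conf = DictConf({"spark.ssl.enabled": "TRUE"})
runtime_conf = DictConf({})
print(conf_to_bool("spark.ssl.enabled", get_conf(False, spark_conf, runtime_conf, "spark.ssl.enabled", "false")))  # True
print(conf_to_bool("spark.ssl.enabled", get_conf(True, spark_conf, runtime_conf, "spark.ssl.enabled", "false")))  # False
```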



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153071026


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,511 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import contextlib
+import os
+import shutil
+from six import StringIO
+import stat
+import subprocess
+import sys
+import time
+import tempfile
+import threading
+from typing import Callable, Dict, Any
+import unittest
+from unittest.mock import patch
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark import SparkConf
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+@contextlib.contextmanager
+def patch_stdout() -> StringIO:
+    """patch stdout and give an output"""
+    sys_stdout = sys.stdout
+    io_out = StringIO()
+    sys.stdout = io_out
+    try:
+        yield io_out
+    finally:
+        sys.stdout = sys_stdout
+
+
+def create_training_function(mnist_dir_path: str) -> Callable:
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torchvision import transforms, datasets
+
+    batch_size = 100
+    num_epochs = 1
+    momentum = 0.5
+
+    train_dataset = datasets.MNIST(
+        mnist_dir_path,
+        train=True,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
+
+    class Net(nn.Module):
+        def __init__(self) -> None:
+            super(Net, self).__init__()
+            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+            self.conv2_drop = nn.Dropout2d()
+            self.fc1 = nn.Linear(320, 50)
+            self.fc2 = nn.Linear(50, 10)
+
+        def forward(self, x: Any) -> Any:
+            x = F.relu(F.max_pool2d(self.conv1(x), 2))
+            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+            x = x.view(-1, 320)
+            x = F.relu(self.fc1(x))
+            x = F.dropout(x, training=self.training)
+            x = self.fc2(x)
+            return F.log_softmax(x)
+
+    def train_fn(learning_rate: float) -> Any:
+        import torch
+        import torch.optim as optim
+        import torch.distributed as dist
+        from torch.nn.parallel import DistributedDataParallel as DDP
+        from torch.utils.data.distributed import DistributedSampler
+
+        dist.init_process_group("gloo")
+
+        train_sampler = DistributedSampler(dataset=train_dataset)
+        data_loader = torch.utils.data.DataLoader(
+            train_dataset, batch_size=batch_size, sampler=train_sampler
+        )
+
+        model = Net()
+        ddp_model = DDP(model)
+        optimizer = optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=momentum)
+        for epoch in range(1, num_epochs + 1):
+            model.train()
+            for _, (data, target) in enumerate(data_loader):
+                optimizer.zero_grad()
+                output = model(data)
+                loss = F.nll_loss(output, target)
+                loss.backward()
+                optimizer.step()
+            print(f"epoch {epoch} finished.")
+
+        return "success"
+
+    return train_fn
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key in input_map.keys():
+            del os.environ[key]
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            (1, True, False),
+            (100, True, False),
+            (1, False, False),
+            (100, False, False),
+        ]
+        for num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                expected_params = {
+                    "num_processes": num_processes,
+                    "local_mode": local_mode,
+                    "use_gpu": use_gpu,
+                    "num_tasks": num_processes,
+                }
+                dist = TorchDistributor(num_processes, local_mode, use_gpu, spark=self.spark)
+                self.assertEqual(expected_params, dist.input_params)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            (0, False, False, ValueError, "positive"),
+        ]
+        for num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(num_processes, local_mode, use_gpu)
+
+    # TODO: Should support read SparkConf and initialize Remote Session with SSL options
+    # def test_encryption_passes(self) -> None:
+    #     inputs = [
+    #         ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+    #         ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+    #         ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+    #     ]
+    #     for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+    #         with self.subTest():
+    #             self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+    #             self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+    #             distributor = TorchDistributor(1, True, False)
+    #             distributor._check_encryption()
+
+    # def test_encryption_fails(self) -> None:
+    #     # this is the only combination that should fail
+    #     inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+    #     for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+    #         with self.subTest():
+    #             with self.assertRaisesRegex(Exception, "encryption"):
+    #                 self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+    #                 self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+    #                 distributor = TorchDistributor(1, True, False)
+    #                 distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+    def test_execute_command(self) -> None:
+        """Test that run command runs the process and logs are written correctly"""
+
+        with patch_stdout() as output:
+            stdout_command = ["echo", "hello_stdout"]
+            TorchDistributor._execute_command(stdout_command)
+            self.assertIn(
+                "hello_stdout", output.getvalue().strip(), "hello_stdout should print to stdout"
+            )
+
+        with patch_stdout() as output:
+            stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
+            TorchDistributor._execute_command(stderr_command)
+            self.assertIn(
+                "hello_stderr", output.getvalue().strip(), "hello_stderr should print to stdout"
+            )
+
+        # include command in the exception message
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
+            error_command = ["bash", "-c", "exit 1"]
+            TorchDistributor._execute_command(error_command)
+
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
+            error_command = ["bash", "-c", "'abc''def'"]
+            TorchDistributor._execute_command(error_command)
+
+    def test_create_torchrun_command(self) -> None:
+        train_path = "train.py"
+        args_string = ["1", "3"]
+        local_mode_input_params = {"num_processes": 4, "local_mode": True}
+
+        expected_local_mode_output = [
+            sys.executable,
+            "-m",
+            "pyspark.ml.torch.torch_run_process_wrapper",
+            "--standalone",
+            "--nnodes=1",
+            "--nproc_per_node=4",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                local_mode_input_params, train_path, *args_string
+            ),
+            expected_local_mode_output,
+        )
+
+        distributed_mode_input_params = {"num_processes": 4, "local_mode": False}
+        input_env_vars = {"MASTER_ADDR": "localhost", "MASTER_PORT": "9350", "RANK": "3"}
+
+        args_number = [1, 3]  # testing conversion to strings
+        self.setup_env_vars(input_env_vars)
+        expected_distributed_mode_output = [
+            sys.executable,
+            "-m",
+            "pyspark.ml.torch.torch_run_process_wrapper",
+            "--nnodes=4",
+            "--node_rank=3",
+            "--rdzv_endpoint=localhost:9350",
+            "--rdzv_id=0",
+            "--nproc_per_node=1",
+            "train.py",
+            "1",
+            "3",
+        ]
+        self.assertEqual(
+            TorchDistributor._create_torchrun_command(
+                distributed_mode_input_params, train_path, *args_number
+            ),
+            expected_distributed_mode_output,
+        )
+        self.delete_env_vars(input_env_vars)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key, value in input_map.items():
+            os.environ[key] = value
+
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
+        for key in input_map.keys():
+            del os.environ[key]
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor(num_processes, True, True, spark=self.spark)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertLogs("TorchDistributor", level="WARNING") as log:
+                    distributor = TorchDistributor(num_processes, True, True, spark=self.spark)
+                    self.assertEqual(len(log.records), 1)
+                    self.assertEqual(distributor.num_processes, 3)
+
+    # TODO: support SparkConf
+    # def test_get_gpus_owned_local(self) -> None:
+    #     addresses = ["0", "1", "2"]
+    #     self.assertEqual(get_gpus_owned(self.sc), addresses)
+    #
+    #     env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"}
+    #     self.setup_env_vars(env_vars)
+    #     self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+    #     self.delete_env_vars(env_vars)
+
+    def test_local_training_succeeds(self) -> None:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        inputs = [
+            ("0,1,2", 1, True, "1"),
+            ("0,1,2", 3, True, "1,2,0"),
+            ("0,1,2", 2, False, "0,1,2"),
+            (None, 3, False, "NONE"),
+        ]
+
+        for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs):
+            with self.subTest(f"subtest: {i + 1}"):
+                # setup
+                if cuda_env_var:
+                    self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
+                dist = TorchDistributor(num_processes, True, use_gpu, spark=self.spark)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
+                    CUDA_VISIBLE_DEVICES, "NONE"
+                )
+                self.assertEqual(
+                    expected,
+                    dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"),
+                )
+                # cleanup
+                if cuda_env_var:
+                    self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
+    def test_local_file_with_pytorch(self) -> None:
+        spark = SparkSession.builder.remote("local[4]").getOrCreate()
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=True, use_gpu=False, spark=spark).run(
+            test_file_path, learning_rate_str
+        )
+
+    def test_end_to_end_run_locally(self) -> None:
+        spark = SparkSession.builder.remote("local[4]").getOrCreate()
+        train_fn = create_training_function(self.mnist_dir_path)
+        output = TorchDistributor(num_processes=2, local_mode=True, use_gpu=False, spark=spark).run(
+            train_fn, 0.001
+        )
+        self.assertEqual(output, "success")
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set(
+            "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+        )
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config(conf=conf)
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    def test_dist_training_succeeds(self) -> None:
+        CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+        inputs = [
+            ("0,1,2", 2, True, "0"),
+        ]
+
+        for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
+            with self.subTest(f"subtest: {i + 1}"):
+                dist = TorchDistributor(num_processes, False, use_gpu, spark=self.spark)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
+                    CUDA_VISIBLE_DEVICES, "NONE"
+                )
+                self.assertEqual(
+                    expected,
+                    dist._run_distributed_training(dist._run_training_on_pytorch_file, "..."),
+                )
+
+    # Can not set SparkConf
+    # def test_get_num_tasks_distributed(self) -> None:
+    #     inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+    #
+    #     for spark_conf_value, num_processes, expected_output in inputs:
+    #         with self.subTest():
+    #             self.spark.sparkContext._conf.set(
+    #                 "spark.task.resource.gpu.amount", str(spark_conf_value)
+    #             )
+    #             distributor = TorchDistributor(num_processes, False, True, self.spark)
+    #             self.assertEqual(distributor._get_num_tasks(), expected_output)
+    #
+    #     self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+    def test_distributed_file_with_pytorch(self) -> None:
+        test_file_path = "python/test_support/test_pytorch_training_file.py"
+        learning_rate_str = "0.01"
+        TorchDistributor(num_processes=2, local_mode=False, use_gpu=False, spark=self.spark).run(
+            test_file_path, learning_rate_str
+        )
+
+    def test_end_to_end_run_distributedly(self) -> None:
+        train_fn = create_training_function(self.mnist_dir_path)
+        output = TorchDistributor(
+            num_processes=2, local_mode=False, use_gpu=False, spark=self.spark
+        ).run(train_fn, 0.001)
+        self.assertEqual(output, "success")
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchWrapperUnitTests(unittest.TestCase):
+    def test_clean_and_terminate(self) -> None:
+        def kill_task(task: "subprocess.Popen") -> None:
+            time.sleep(1)
+            clean_and_terminate(task)
+
+        command = [sys.executable, "-c", '"import time; time.sleep(20)"']
+        task = subprocess.Popen(command)
+        t = threading.Thread(target=kill_task, args=(task,))
+        t.start()
+        time.sleep(2)
+        self.assertEqual(task.poll(), 0)  # implies task ended
+
+    @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate")
+    def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
+        command = [sys.executable, "-c", '"import time; time.sleep(2)"']
+        task = subprocess.Popen(command)
+        t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+        t.start()
+        time.sleep(2)
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)

Review Comment:
   We can certainly reuse it; for now I just copied it for convenience.
   Once this PR is ready, it will reuse the existing unit tests.
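Reusing existing unit tests for a parity suite is commonly done by moving the test bodies into a mixin that does not inherit from `unittest.TestCase`, so that each concrete class only supplies its own session setup. The following is a minimal sketch under that assumption. `DistributorTestsMixin`, `create_session`, and the dict "sessions" are hypothetical placeholders, not the PR's classes; a real Connect parity class would build its session with `SparkSession.builder.remote(...)` as in the diff.

```python
import unittest
from typing import Dict


class DistributorTestsMixin:
    """Shared test bodies; not a TestCase itself, so unittest does not collect it alone."""

    def create_session(self) -> Dict[str, str]:
        raise NotImplementedError  # each concrete test class decides how to build the session

    def setUp(self) -> None:
        self.session = self.create_session()

    def test_session_is_available(self) -> None:
        # Written once, run against every session kind that mixes this class in.
        self.assertIn(self.session["kind"], ("classic", "connect"))


class ClassicDistributorTests(DistributorTestsMixin, unittest.TestCase):
    def create_session(self) -> Dict[str, str]:
        return {"kind": "classic"}  # placeholder for a classic SparkSession


class ConnectDistributorParityTests(DistributorTestsMixin, unittest.TestCase):
    def create_session(self) -> Dict[str, str]:
        return {"kind": "connect"}  # placeholder for SparkSession.builder.remote(...).getOrCreate()
```

Running `python -m unittest` on the module executes the shared test once per concrete class, which keeps a single copy of each test body.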



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155430021


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote Spark Connect session, the SparkConf
+    code path always fails, so this falls back to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   Q: Why not use `spark.conf.get` here too? Then we wouldn't need two branches.
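A sketch of what the suggestion amounts to: a single lookup through the session's runtime config, which both classic and Connect sessions expose as `spark.conf`. Whether `spark.conf.get` also returns static confs such as `spark.task.resource.gpu.amount` on a classic session is an assumption worth verifying. `get_conf`, `FakeConf` and `FakeSession` are hypothetical names; the fakes only exist so the sketch runs without Spark.

```python
from typing import Dict, Optional


class FakeConf:
    """Hypothetical stand-in for pyspark's RuntimeConfig (`spark.conf`)."""

    def __init__(self, values: Dict[str, str]) -> None:
        self._values = values

    def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
        return self._values.get(key, default)


class FakeSession:
    """Hypothetical stand-in exposing only the `conf` attribute of a session."""

    def __init__(self, values: Dict[str, str]) -> None:
        self.conf = FakeConf(values)


def get_conf(spark, key: str, default_value: str) -> str:
    """Read `key` through the runtime config only, with no is_remote() branch."""
    value = spark.conf.get(key, default_value)
    return default_value if value is None else value
```

The caller in `_get_num_tasks` would then read `int(get_conf(self.spark, key, "0"))` the same way for both session types.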




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153066202


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -581,11 +593,11 @@ def _run_distributed_training(
             f"Started distributed training with {self.num_processes} executor proceses"
         )
         try:
+            assert self.spark is not None
             result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="output binary", barrier=True)
+                .first()["output"]

Review Comment:
   If I remember correctly, barrier-mode tasks do not support the `.first()` operation.
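If I recall Spark's barrier-mode check correctly, it rejects RDD chains that do not run on every partition, and its error message names `first()`/`take()` among the unsupported patterns, suggesting `collect()[0]` instead. Since only partition 0 yields a row here, collecting the whole result and picking the single row avoids the problem. A sketch, where `df` is assumed to be the DataFrame returned by the barrier `mapInPandas` call in the diff; `collect_single_output` and `FakeDataFrame` are hypothetical, the latter only making the snippet runnable.

```python
from typing import Any, Dict, List


class FakeDataFrame:
    """Hypothetical stand-in for the DataFrame produced by the barrier mapInPandas."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self._rows = rows

    def collect(self) -> List[Dict[str, Any]]:
        return list(self._rows)


def collect_single_output(df, column: str = "output") -> Any:
    # collect() runs every task of the barrier stage, unlike first()/take(1),
    # which may try to evaluate only a subset of the partitions.
    rows = df.collect()
    if len(rows) != 1:
        raise RuntimeError(f"Expected exactly one output row, got {len(rows)}.")
    return rows[0][column]
```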




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155433317


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +568,7 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                yield pd.DataFrame(data={"output": [cloudpickle.dumps(output)]})

Review Comment:
   Since the output DataFrame contains only one row, I think `DataFrame.collect` cannot handle it properly.




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "HyukjinKwon (via GitHub)" <gi...@apache.org>.
HyukjinKwon commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155403938


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -176,20 +184,19 @@ def _get_num_tasks(self) -> int:
         RuntimeError
             Raised when the SparkConf was misconfigured.
         """
-
         if self.use_gpu:
             if not self.local_mode:
                 key = "spark.task.resource.gpu.amount"
-                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
                 # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
                 return math.ceil(self.num_processes / task_gpu_amount)
             else:
                 key = "spark.driver.resource.gpu.amount"
-                if "gpu" not in self.sc.resources:
+                if "gpu" not in self.spark.sparkContext.resources:

Review Comment:
   We should probably at least add a comment here that this doesn't work with Spark Connect (or add a TODO with a JIRA).
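One way to act on this, sketched with hypothetical names: keep the check on `SparkContext.resources`, but fail with an explicit message and a TODO pointing at SPARK-42994 (the JIRA this PR uses elsewhere for `sc.resources`) when running over Spark Connect. The `remote` flag stands in for `pyspark.sql.utils.is_remote()`, and `resources` for `spark.sparkContext.resources`, assumed to map a resource name to an object with an `addresses` list.

```python
from types import SimpleNamespace
from typing import Any, Mapping


def num_driver_gpus(resources: Mapping[str, Any], remote: bool) -> int:
    """Count the GPUs visible to the driver; a hypothetical helper, not the PR's code."""
    # TODO(SPARK-42994): Support sc.resources with Spark Connect.
    if remote:
        raise NotImplementedError(
            "Local mode with GPUs is not supported with Spark Connect yet."
        )
    if "gpu" not in resources:
        raise RuntimeError("GPUs were unable to be found on the driver.")
    return len(resources["gpu"].addresses)


# Usage with a fake resource entry shaped like pyspark's ResourceInformation.
example_resources = {"gpu": SimpleNamespace(addresses=["0", "1", "2"])}
```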




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155407189


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -176,20 +184,19 @@ def _get_num_tasks(self) -> int:
         RuntimeError
             Raised when the SparkConf was misconfigured.
         """
-
         if self.use_gpu:
             if not self.local_mode:
                 key = "spark.task.resource.gpu.amount"
-                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
                 # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
                 return math.ceil(self.num_processes / task_gpu_amount)
             else:
                 key = "spark.driver.resource.gpu.amount"
-                if "gpu" not in self.sc.resources:
+                if "gpu" not in self.spark.sparkContext.resources:

Review Comment:
   thanks, will add the comment




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155432023


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import stat
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_passes(self):
+        super().test_encryption_passes()
+
+    @unittest.skip("Can not dynamically set ssl conf via spark.sparkContext._conf.")
+    def test_encryption_fails(self):
+        super().test_encryption_fails()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config("spark.driver.resource.gpu.amount", "3")
+            .config(
+                "spark.driver.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .remote("local-cluster[2,2,1024]")
+            .getOrCreate()
+        )
+
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.gpu_discovery_script_file = tempfile.NamedTemporaryFile(delete=False)
+        self.gpu_discovery_script_file.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.gpu_discovery_script_file.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.gpu_discovery_script_file.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        self.spark = (
+            SparkSession.builder.appName(class_name)
+            .config("spark.test.home", SPARK_HOME)
+            .config(
+                "spark.worker.resource.gpu.discoveryScript", self.gpu_discovery_script_file.name
+            )
+            .config("spark.worker.resource.gpu.amount", "3")
+            .config("spark.task.cpus", "2")
+            .config("spark.task.resource.gpu.amount", "1")
+            .config("spark.executor.resource.gpu.amount", "1")
+            .remote("local-cluster[2,2,1024]")

Review Comment:
   Can we reuse the similar code on the spark.ml side (in TorchDistributorDistributedUnitTests)?
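The later revision of this file does reuse it: `setUp` takes the pairs from the classic suite's `self._get_spark_conf()` and copies everything except `spark.master`, `spark.remote` and `spark.app.name` into the Connect builder. The filtering step on its own, as a plain-Python sketch; the helper name is hypothetical and `pairs` plays the role of `conf.getAll()`.

```python
from typing import FrozenSet, Iterable, List, Tuple

# Keys that are decided by how the Connect session is created, so they are not copied.
EXCLUDED_KEYS = frozenset({"spark.master", "spark.remote", "spark.app.name"})


def connect_builder_configs(
    pairs: Iterable[Tuple[str, str]],
    excluded: FrozenSet[str] = EXCLUDED_KEYS,
) -> List[Tuple[str, str]]:
    """Return the (key, value) pairs that should be passed to builder.config(k, v)."""
    return [(k, v) for k, v in pairs if k not in excluded]
```

In the test, each returned pair would feed `builder = builder.config(k, v)` before `.remote("local-cluster[2,2,1024]").getOrCreate()`.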




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155808647


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -263,6 +265,10 @@ class TorchDistributor(Distributor):
 
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect. Note local model with GPU is not supported yet, will be fixed

Review Comment:
   good catch, will fix.




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "HyukjinKwon (via GitHub)" <gi...@apache.org>.
HyukjinKwon commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1157112323


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()

Review Comment:
   They are tested one by one in CI, see https://github.com/apache/spark/blob/master/.github/workflows/build_and_test.yml#L404.
   
   Can you try decreasing the memory here?
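For context, the master URL `local-cluster[2,2,1024]` means 2 workers with 2 cores and 1024 MB of memory each, so the workers together advertise 2048 MB to executors before the driver and the Python worker processes are counted. A small hypothetical helper that builds the URL and reports that total makes it easier to try smaller values:

```python
def local_cluster_master(num_workers: int, cores_per_worker: int, memory_per_worker_mb: int) -> str:
    """Build a local-cluster master URL: [workers, cores per worker, MB per worker]."""
    if min(num_workers, cores_per_worker, memory_per_worker_mb) <= 0:
        raise ValueError("all local-cluster parameters must be positive")
    return f"local-cluster[{num_workers},{cores_per_worker},{memory_per_worker_mb}]"


def total_worker_memory_mb(num_workers: int, memory_per_worker_mb: int) -> int:
    """Memory advertised by all workers combined, in MB."""
    return num_workers * memory_per_worker_mb
```

For example, `builder.remote(local_cluster_master(2, 2, 512))` would halve the worker memory compared to the current setting.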




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1159434752


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {

Review Comment:
   ```
   Started distributed training with 2 executor processes
   java.lang.OutOfMemoryError: Java heap space
   	at java.util.Arrays.copyOfRange(Arrays.java:3664)
   	at java.lang.String.<init>(String.java:207)
   	at java.lang.StringBuilder.toString(StringBuilder.java:407)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:112)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:119)
   	at org.sparkproject.connect.protobuf.TextFormat.escapeBytes(TextFormat.java:2364)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:593)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   Extracting /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw
   ```
   
   `v.toString` keeps throwing OOM in `TorchDistributorDistributedUnitTestsOnConnect`.
   This OOM seems related to the Java version: it was thrown on both Linux + Java 8 and macOS + Java 8, but does not occur on macOS + Java 11.
   
   
   The free GitHub Actions resources are limited to 2 CPUs and 6 GB of memory (confirmed with @Yikun), and I believe we cannot allocate enough driver memory for distributed PyTorch training without this fix.
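The trace suggests why abbreviating afterwards does not help: `StringUtils.abbreviate` only truncates a string that `v.toString` has already built in full, including escaped `bytes` fields (presumably the serialized Python function). The diff only shows that a `try {` was added around this code, so what it catches is not visible here. A Python sketch of the same idea, assuming `abbreviate` mirrors commons-lang3's behaviour (keep `max_width - 3` characters, then append `...`); `render` is a hypothetical callable standing in for `v.toString`, and catching the memory error is an assumption, not the PR's code.

```python
from typing import Callable, Tuple


def abbreviate(text: str, max_width: int) -> str:
    """Mimics commons-lang3 StringUtils.abbreviate(str, maxWidth) for max_width >= 4."""
    if max_width < 4:
        raise ValueError("minimum abbreviation width is 4")
    if len(text) <= max_width:
        return text
    return text[: max_width - 3] + "..."


def call_site_properties(render: Callable[[], str]) -> Tuple[str, str]:
    """Build the short/long call-site strings, degrading instead of failing."""
    try:
        debug_string = render()
    except MemoryError:
        # The JVM analogue would be OutOfMemoryError; handling it here is an assumption.
        debug_string = "<plan too large to render>"
    short = f"Spark Connect - {abbreviate(debug_string, 128)}"
    return short, abbreviate(debug_string, 2048)
```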




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155446538


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +568,7 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                yield pd.DataFrame(data={"output": [cloudpickle.dumps(output)]})

Review Comment:
   Sounds good: split the row into multiple rows in partition 0, and then concatenate them on the driver.
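A plain-Python sketch of that plan, assuming the pickled result is split into fixed-size pieces tagged with their position (an `index` field the diff does not have) so the driver can reassemble them regardless of row order. The function names are hypothetical, `pickle` stands in for `cloudpickle`, and the chunk size is an arbitrary illustrative value.

```python
import pickle
from typing import Any, Dict, Iterable, List

CHUNK_SIZE = 1 << 20  # 1 MiB per row; illustrative only


def to_chunk_rows(output: Any, chunk_size: int = CHUNK_SIZE) -> List[Dict[str, Any]]:
    """Partition-0 side: serialize once, then emit one row per chunk."""
    payload = pickle.dumps(output)  # never empty, so at least one chunk is produced
    return [
        {"index": i, "output": payload[start : start + chunk_size]}
        for i, start in enumerate(range(0, len(payload), chunk_size))
    ]


def from_chunk_rows(rows: Iterable[Dict[str, Any]]) -> Any:
    """Driver side: order chunks by index, concatenate, and unpickle."""
    ordered = sorted(rows, key=lambda row: row["index"])
    return pickle.loads(b"".join(row["output"] for row in ordered))
```

In the distributor, partition 0 would presumably yield `pd.DataFrame(to_chunk_rows(output))` with a schema along the lines of `"index int, output binary"`, and the driver would pass the collected rows to `from_chunk_rows`.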




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor support Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155039367


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -581,12 +590,12 @@ def _run_distributed_training(
             f"Started distributed training with {self.num_processes} executor proceses"

Review Comment:
   thanks, will fix




[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1495768111

   ```
   pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
   	status = StatusCode.UNKNOWN
   	details = "Java heap space"
   	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:39429 {grpc_message:"Java heap space", grpc_status:2, created_time:"2023-04-04T10:54:38.384319228+00:00"}"
   >
   ```
   It fails again even after I switched back to the initial approach. The error is raised on the server side, so it should not be related to how the chunks are concatenated. I guess there is not enough RAM.
   
   cc @HyukjinKwon, are those UTs run in parallel or one by one?




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1157114526


##########
python/pyspark/ml/tests/connect/test_parity_torch_distributor.py:
##########
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()

Review Comment:
   Yeah, I am just trying to reduce the executor memory, since in the distributor the real computation happens only in torch, not in Spark.
   Thanks!
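For reference, the `local-cluster[N,C,M]` master string used in `setUp` encodes three values: the number of workers, the cores per worker, and the memory per worker in MiB. As far as we know, Spark rejects a setup whose executor memory exceeds the per-worker memory, so the two settings have to be tuned together. The helper below is a hypothetical sketch of that check. `parse_local_cluster`, `check_executor_memory`, and `LocalClusterSpec` are illustrative names, not PySpark APIs.

```python
import re
from typing import NamedTuple

# local-cluster[numWorkers, coresPerWorker, memoryPerWorkerMiB], whitespace allowed.
_LOCAL_CLUSTER = re.compile(r"local-cluster\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]")


class LocalClusterSpec(NamedTuple):
    num_workers: int
    cores_per_worker: int
    memory_per_worker_mib: int


def parse_local_cluster(master: str) -> LocalClusterSpec:
    """Split a local-cluster master string into its three numeric fields."""
    match = _LOCAL_CLUSTER.fullmatch(master.strip())
    if match is None:
        raise ValueError(f"not a local-cluster master string: {master!r}")
    workers, cores, memory = (int(g) for g in match.groups())
    return LocalClusterSpec(workers, cores, memory)


def check_executor_memory(master: str, executor_memory_mib: int) -> LocalClusterSpec:
    """Fail early if each executor would need more memory than a worker offers."""
    spec = parse_local_cluster(master)
    if executor_memory_mib > spec.memory_per_worker_mib:
        raise ValueError(
            f"{spec.memory_per_worker_mib} MiB per worker cannot host "
            f"a {executor_memory_mib} MiB executor"
        )
    return spec


# The setUp above uses two workers with two cores and 1024 MiB each.
spec = check_executor_memory("local-cluster[2,2,1024]", executor_memory_mib=512)
```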





[GitHub] [spark] WeichenXu123 commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1495655168

   Followup tasks:
   
   * We should replace `mapInPandas` with `mapInArrow` for better performance
   * For each Spark task, we should save the partition data to local disk (in Arrow format) before starting the torch process, and provide utility reading methods that the torch program (running as a child process) can invoke to iterate over the partition data.
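The second follow-up could look roughly like the sketch below. It is a stdlib-only stand-in: it writes length-prefixed `pickle` records rather than the Arrow format proposed above. With `pyarrow`, the same shape would use `pyarrow.ipc.new_file` for writing and `pyarrow.ipc.open_file` for reading. `save_partition` and `iter_partition` are hypothetical names. Under this design, the Spark task would call the first function before launching torch, and the torch child would call the second with the file path it is given.

```python
import os
import pickle
import struct
import tempfile
from typing import Iterable, Iterator, List

# Each record: a 4-byte big-endian length prefix followed by one pickled batch.
_LEN = struct.Struct(">I")


def save_partition(batches: Iterable[List[dict]], path: str) -> None:
    """Spark-task side: persist every batch of the partition before torch starts."""
    with open(path, "wb") as f:
        for batch in batches:
            blob = pickle.dumps(batch)
            f.write(_LEN.pack(len(blob)))
            f.write(blob)


def iter_partition(path: str) -> Iterator[List[dict]]:
    """Torch-child side: stream the batches back one at a time."""
    with open(path, "rb") as f:
        while True:
            header = f.read(_LEN.size)
            if not header:  # clean end of file
                return
            (size,) = _LEN.unpack(header)
            yield pickle.loads(f.read(size))


# Example: round-trip two batches through a temporary file.
with tempfile.TemporaryDirectory() as tmp_dir:
    part_path = os.path.join(tmp_dir, "partition-0.bin")
    save_partition([[{"x": 1}], [{"x": 2}, {"x": 3}]], part_path)
    loaded = list(iter_partition(part_path))
```

Streaming batch by batch keeps the torch process from having to load the whole partition into memory at once.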




[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1495858338

   The GA status is not shown: https://github.com/zhengruifeng/spark/actions/runs/4607449033




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156596408


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   We should use `join` instead, which is faster. Will update.
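The following small illustration shows why `join` is usually preferred. In CPython, repeated `+=` on immutable `bytes` allocates a new object and copies the accumulated prefix on every step, so total copying can grow quadratically with the number of chunks. In contrast, `b"".join` sizes the result once and copies each chunk a single time. The function names are illustrative only.

```python
from typing import List


def concat_with_plus(chunks: List[bytes]) -> bytes:
    # Each += builds a new bytes object holding everything accumulated so far.
    out = b""
    for chunk in chunks:
        out += chunk
    return out


def concat_with_join(chunks: List[bytes]) -> bytes:
    # join computes the total length up front and copies every chunk once.
    return b"".join(chunks)


# 64 chunks of 4 KiB, similar in shape to the collected "chunk" rows.
chunks = [bytes([i % 256]) * 4096 for i in range(64)]
joined = concat_with_join(chunks)
```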





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1159434752


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {

Review Comment:
   ```
   Started distributed training with 2 executor processes
   java.lang.OutOfMemoryError: Java heap space
   	at java.util.Arrays.copyOfRange(Arrays.java:3664)
   	at java.lang.String.<init>(String.java:207)
   	at java.lang.StringBuilder.toString(StringBuilder.java:407)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:112)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:119)
   	at org.sparkproject.connect.protobuf.TextFormat.escapeBytes(TextFormat.java:2364)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:593)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   Extracting /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw
   ```
   
   `v.toString` keeps throwing OOM in `TorchDistributorDistributedUnitTestsOnConnect`.
   This OOM seems related to the Java version: it was thrown on both Linux+Java8 and macOS+Java8, but does not occur on macOS+Java11.
   
   The free GA resources are limited to 2 CPUs and 6 GB of memory (confirmed with @Yikun), and I believe we cannot allocate enough driver memory for this distributed PyTorch training UT without this fix.





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1159434752


##########
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##########
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {

Review Comment:
   ```
   Started distributed training with 2 executor processes
   java.lang.OutOfMemoryError: Java heap space
   	at java.util.Arrays.copyOfRange(Arrays.java:3664)
   	at java.lang.String.<init>(String.java:207)
   	at java.lang.StringBuilder.toString(StringBuilder.java:407)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:112)
   	at org.sparkproject.connect.protobuf.TextFormatEscaper.escapeBytes(TextFormatEscaper.java:119)
   	at org.sparkproject.connect.protobuf.TextFormat.escapeBytes(TextFormat.java:2364)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:593)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printSingleField(TextFormat.java:752)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printField(TextFormat.java:457)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printMessage(TextFormat.java:714)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.print(TextFormat.java:367)
   	at org.sparkproject.connect.protobuf.TextFormat$Printer.printFieldValue(TextFormat.java:606)
   Extracting /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/ruifeng.zheng/spark/python/target/50b3d81f-67a9-46de-9beb-59b733c16e54/tmp72g322ys/MNIST/raw
   ```
   
   `v.toString` keeps hitting OOM in `TorchDistributorDistributedUnitTestsOnConnect`.
   This OOM is related to the OS or Java version: it was thrown on Linux+Java8, but does not occur in my local environment (macOS+Java11).
   
   The free GA resources are limited to 2 CPUs and 6 GB of memory (confirmed with @Yikun), and I cannot allocate enough driver memory for distributed PyTorch training without this fix.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155430451


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +568,7 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                yield pd.DataFrame(data={"output": [cloudpickle.dumps(output)]})

Review Comment:
   Q: Did you test what happens if the output exceeds 2GB? I guess it is not supported. Shall we split it in that case?





[GitHub] [spark] jaceklaskowski commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor support Spark Connect

Posted by "jaceklaskowski (via GitHub)" <gi...@apache.org>.
jaceklaskowski commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1154524221


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -581,12 +590,12 @@ def _run_distributed_training(
             f"Started distributed training with {self.num_processes} executor proceses"

Review Comment:
   s/proceses/processes





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1155430910


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -32,54 +32,65 @@
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        spark = _active_spark_session  # type: ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
+    If this session is a remote connect session, the SparkConf
+    code path always fail, and fallback to the RuntimeConf.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
         default value for the conf value for the given key
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        value = spark.sparkContext.getConf().get(key, default_value)

Review Comment:
   Some tests in `test_distributor` directly modify the cached conf `sc._conf`, so they won't work with `spark.conf.get`.
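To make the two code paths concrete, here is a hedged sketch. Plain dictionaries stand in for two things: the SparkContext's cached `SparkConf`, which these tests patch through `sc._conf`, and the session's runtime conf. `get_conf_value` and `conf_to_bool` are hypothetical names. The boolean parsing mirrors the removed `get_conf_boolean` helper.

```python
from typing import Mapping


def get_conf_value(
    remote: bool,
    cached_spark_conf: Mapping[str, str],
    runtime_conf: Mapping[str, str],
    key: str,
    default_value: str,
) -> str:
    """Classic sessions read the cached SparkConf, so values patched by tests are
    visible; Spark Connect sessions have no SparkContext and use the runtime conf."""
    source = runtime_conf if remote else cached_spark_conf
    return source.get(key, default_value)


def conf_to_bool(key: str, value: str) -> bool:
    """Case-insensitive "true"/"false" parsing, as in the removed get_conf_boolean."""
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    raise ValueError(
        f"The conf value for '{key}' was expected to be a boolean value but found: {value}"
    )


# A test that patched only the cached conf is visible in classic mode but not remotely.
cached = {"example.flag": "True"}
runtime = {"example.flag": "false"}
classic_value = get_conf_value(False, cached, runtime, "example.flag", "false")
remote_value = get_conf_value(True, cached, runtime, "example.flag", "true")
```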





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156574015


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -554,7 +561,21 @@ def set_gpus(context: "BarrierTaskContext") -> None:
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                output_bytes = cloudpickle.dumps(output)
+                output_size = len(output_bytes)
+
+                # In Spark Connect, DataFrame.collect stacks rows to size
+                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+                # here use 4KiB for each chunk, which mean each arrow batch
+                # may contain about 1000 chunks.
+                chunks = []
+                chunk_size = 4096

Review Comment:
   The Arrow batch size is actually controlled by `spark.connect.grpc.arrow.maxBatchSize`, and if a batch exceeds another conf (grpc...max_size), it fails.
   So I think we don't need a larger chunk size here.
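The following back-of-the-envelope arithmetic reproduces the numbers in that code comment. It assumes the 4 MiB default for `spark.connect.grpc.arrow.maxBatchSize` quoted in the diff, along with 4 KiB chunks. It ignores Arrow metadata and per-row offsets, which is why the real figure is "about 1000" rather than exactly 1024. The helper names are hypothetical.

```python
# Values quoted in the PR discussion; assumed here, not read from a live session.
MAX_BATCH_SIZE = 4 * 1024 * 1024  # spark.connect.grpc.arrow.maxBatchSize default (4 MiB)
CHUNK_SIZE = 4 * 1024  # bytes of payload per "chunk" row (4 KiB)

# Upper bound on chunk rows per Arrow batch, ignoring Arrow's own overhead.
CHUNKS_PER_BATCH = MAX_BATCH_SIZE // CHUNK_SIZE


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def num_chunks(payload_size: int) -> int:
    """Rows needed to carry `payload_size` bytes of pickled output."""
    return ceil_div(payload_size, CHUNK_SIZE)


def approx_num_batches(payload_size: int) -> int:
    """Lower-bound estimate of Arrow batches streamed back to the client."""
    return ceil_div(num_chunks(payload_size), CHUNKS_PER_BATCH)


# Example: a 100 MiB pickled training output.
rows_for_100_mib = num_chunks(100 * 1024 * 1024)
batches_for_100_mib = approx_num_batches(100 * 1024 * 1024)
```

With these assumptions, a 100 MiB output becomes 25600 rows spread over at least 25 batches. Each batch stays near the 4 MiB target whatever the chunk size, which matches the point that a larger chunk size buys little here.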





[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156751693


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   `join` causes a memory issue on master CI; I'm trying again with `bytearray`.
   
   ```
   pyspark.errors.exceptions.connect.SparkConnectGrpcException: <_MultiThreadedRendezvous of RPC that terminated with:
   	status = StatusCode.UNKNOWN
   	details = "Java heap space"
   	debug_error_string = "UNKNOWN:Error received from peer ipv4:127.0.0.1:35291 {created_time:"2023-04-04T05:17:10.693589954+00:00", grpc_status:2, grpc_message:"Java heap space"}"
   ```
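The `bytearray` variant can be sketched as below; the name `concat_with_bytearray` is hypothetical. A `bytearray` is mutable, so `extend` appends in place with amortized growth instead of rebuilding the whole buffer for every chunk. Since `pickle.loads` accepts any bytes-like object, the buffer can be deserialized directly. As far as we know, this also holds for `cloudpickle.loads`, which re-exports it. Passing the buffer straight through avoids one extra full copy.

```python
import pickle
from typing import Iterable


def concat_with_bytearray(chunks: Iterable[bytes]) -> bytearray:
    # extend() appends in place; the buffer over-allocates as it grows, so each
    # chunk is copied into it roughly once (amortized).
    buf = bytearray()
    for chunk in chunks:
        buf.extend(chunk)
    return buf


payload = pickle.dumps({"epoch": 3, "metrics": [0.9, 0.95]})
chunks = [payload[i : i + 16] for i in range(0, len(payload), 16)]
# pickle.loads accepts any bytes-like object, so no final bytes() copy is needed.
result = pickle.loads(concat_with_bytearray(chunks))
```

This only changes client-side concatenation. It does not affect memory use on the Spark Connect server.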





[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1493485269

   cc @WeichenXu123 @HyukjinKwon I think it is ready for review.




[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1156939935


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -578,19 +600,23 @@ def _run_distributed_training(
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor proceses"
+            f"Started distributed training with {self.num_processes} executor processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", barrier=True)
+                .collect()
             )
+            output_bytes = b""
+            for row in rows:
+                output_bytes += row.chunk

Review Comment:
   That is true. It is odd, since I cannot reproduce it locally; it needs further investigation.
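For context on the `chunk binary` schema in the diff above, here is a minimal sketch of how a task could split its serialized output into bounded-size chunks and how the driver would reassemble them. `CHUNK_SIZE`, `split_into_chunks`, and `reassemble` are hypothetical names; in the real task function each chunk would be yielded as a pandas DataFrame with a `chunk` column, which this dependency-free sketch leaves out. It also assumes a single task emits all chunks and that `collect()` returns them in order.

```python
from typing import Iterable, Iterator

# Hypothetical chunk size; the PR does not state one.
CHUNK_SIZE = 4 * 1024 * 1024


def split_into_chunks(payload: bytes, chunk_size: int = CHUNK_SIZE) -> Iterator[bytes]:
    """Yield consecutive slices of `payload`, each at most `chunk_size` bytes."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(payload), chunk_size):
        yield payload[start:start + chunk_size]


def reassemble(chunks: Iterable[bytes]) -> bytes:
    # Driver-side counterpart: concatenate chunks in the order they arrive.
    return b"".join(chunks)


payload = bytes(range(256)) * 10  # 2560 bytes
chunks = list(split_into_chunks(payload, chunk_size=1000))
print([len(c) for c in chunks])  # [1000, 1000, 560]
```

Keeping each row small bounds the size of any single value sent back to the driver, although the driver still holds the full result after `collect()`.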



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153066202


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -581,11 +593,11 @@ def _run_distributed_training(
             f"Started distributed training with {self.num_processes} executor proceses"
         )
         try:
+            assert self.spark is not None
             result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+                self.spark.range(start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="output binary", barrier=True)
+                .first()["output"]

Review Comment:
   If I remember correctly, a DataFrame produced by a barrier-mode mapping does not support the `.first()` operation.
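Barrier stages are meant to launch all of their tasks together, so actions that may evaluate only some partitions, such as `.first()`, are likely to be rejected. The RDD version in the diff avoided this with `.collect()[0]`. A minimal sketch of the same pattern on the driver, assuming the rows come from `df.collect()`; `OutputRow` and `first_output` are hypothetical, and no Spark call is exercised here.

```python
from collections import namedtuple
from typing import Sequence

# Hypothetical stand-in for a collected pyspark Row with one binary "output" column.
OutputRow = namedtuple("OutputRow", ["output"])


def first_output(rows: Sequence[OutputRow]) -> bytes:
    """Return the "output" field of the first collected row.

    `rows` is assumed to come from `df.collect()`, which runs every barrier
    task, instead of `df.first()`, which may only run part of the stage.
    """
    if not rows:
        raise ValueError("the barrier job returned no rows")
    return rows[0].output


print(first_output([OutputRow(b"model"), OutputRow(b"other")]))  # b'model'
```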



[GitHub] [spark] WeichenXu123 commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "WeichenXu123 (via GitHub)" <gi...@apache.org>.
WeichenXu123 commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153067929


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -144,15 +145,21 @@ def __init__(
         num_processes: int = 1,
         local_mode: bool = True,
         use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
     ):
+        if spark is None:
+            self.spark = SparkSession.getActiveSession()

Review Comment:
   Is `getActiveSession` not yet supported on the Spark Connect client? If so, we can use `_active_spark_session` instead.
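A minimal sketch of the session-resolution order discussed here: an explicitly passed `spark` argument wins, then `getActiveSession()`, then the `_active_spark_session` attribute named in the comment. `FakeSession` and `resolve_session` are hypothetical stand-ins so the sketch runs without Spark; whether the Connect client exposes `_active_spark_session` exactly as a class attribute is an assumption.

```python
from typing import Optional


class FakeSession:
    """Hypothetical stand-in for a SparkSession class; not the real pyspark API."""

    # Class-level slot for the current session, named after the review comment.
    _active_spark_session: Optional["FakeSession"] = None

    @classmethod
    def getActiveSession(cls) -> Optional["FakeSession"]:
        # Simulates a client on which getActiveSession is not supported yet.
        raise NotImplementedError("getActiveSession is not supported")


def resolve_session(spark: Optional[FakeSession], session_cls: type = FakeSession) -> FakeSession:
    # 1. An explicitly passed session always wins (useful when testing).
    if spark is not None:
        return spark
    # 2. Try the public accessor first.
    try:
        active = session_cls.getActiveSession()
    except NotImplementedError:
        active = None
    # 3. Fall back to the class-level attribute suggested in the review.
    if active is None:
        active = getattr(session_cls, "_active_spark_session", None)
    if active is None:
        raise RuntimeError("no active Spark session found")
    return active
```

Passing the session explicitly keeps tests independent of global state, while the fallbacks cover the default `spark=None` case.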



[GitHub] [spark] zhengruifeng commented on a diff in pull request #40607: [WIP][ML] Make Torch Distributor support Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on code in PR #40607:
URL: https://github.com/apache/spark/pull/40607#discussion_r1153069439


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -330,6 +340,7 @@ def __init__(
         num_processes: int = 1,
         local_mode: bool = True,
         use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,

Review Comment:
   Right now it is just for testing: I want to specify the session in the tests for debugging.



[GitHub] [spark] zhengruifeng commented on pull request #40607: [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark Connect

Posted by "zhengruifeng (via GitHub)" <gi...@apache.org>.
zhengruifeng commented on PR #40607:
URL: https://github.com/apache/spark/pull/40607#issuecomment-1499783849

   Thanks, all, for the reviews. Merged into master.

