Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/12/21 01:25:27 UTC

[GitHub] [spark] rithwik-db opened a new pull request, #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

rithwik-db opened a new pull request, #39146:
URL: https://github.com/apache/spark/pull/39146

   
   This is a small PR to start progress on the Spark-PyTorch Distributor. It is a WIP, and I have left questions and comments to discuss how I will approach certain aspects of the code.
   
   ### What changes were proposed in this pull request?
   
   This just proposes the baseline API for how users will interact with the Spark PyTorch distributor ([Design Document](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit?usp=sharing)).
   
   ### Why are the changes needed?
   
   The design document's [background](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.cxcvohcybvo2) section goes into more detail about the motivation.
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes, this proposes an [API](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.5ifxsbo5fk8d) for how users will interact with the PyTorch Distributor. The [user workflow](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.8yvw9xq428fh) is also proposed in that design document.
   
   
   ### How was this patch tested?
   
   I added some basic tests. They will need to be improved to match the style that PySpark requires.
   



[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053898579


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean

Review Comment:
   Let's use an absolute import to be consistent across the codebase.
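
   For illustration, a minimal sketch of the requested change (the module path follows the file path shown in the diff above):

   ```
   # Relative import, as in the diff:
   # from .utils import get_conf_boolean

   # Absolute import, consistent with the rest of the codebase:
   from pyspark.ml.torch.utils import get_conf_boolean
   ```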




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1058886036


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters

Review Comment:
   Wait, it seems we should change the logic here to:
   ```
   if self.sc.master.startswith("local"):
     if "gpu" not in self.sc.resources:
       raise RuntimeError(...)
     gpu_amount_raw = 1
   else:
     # if `spark.task.resource.gpu.amount` is not set, spark task
     # cannot use gpu
     gpu_amount_raw = int(self.sc.getConf().get("spark.task.resource.gpu.amount", "0"))
   ```




[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063201881


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import stat
+import tempfile
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            ("pytorch", 1, True, False),
+            ("pytorch", 100, True, False),
+            ("pytorch-lightning", 1, False, False),
+            ("pytorch-lightning", 100, False, False),
+        ]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            ("tensorflow", 1, True, False, ValueError, "framework"),
+            ("pytroch", 100, True, False, ValueError, "framework"),
+            ("pytorchlightning", 1, False, False, ValueError, "framework"),
+            ("pytorch-lightning", 0, False, False, ValueError, "positive"),
+        ]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_encryption_passes(self) -> None:
+        inputs = [
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+        ]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = TorchDistributor("pytorch", 1, True, False)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self) -> None:
+        # this is the only combination that should fail
+        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = TorchDistributor("pytorch", 1, True, False)
+                    distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor("pytorch", num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor("pytorch", num_processes, False, True)
+
+
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.driver.resource.gpu.amount", "3")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertWarns(RuntimeWarning):
+                    distributor = TorchDistributor("pytorch", num_processes, True, True)
+                    distributor.num_processes = 3
+
+
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_distributed(self) -> None:
+        inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(
+                    "spark.task.resource.gpu.amount", str(spark_conf_value)
+                )
+                distributor = TorchDistributor("pytorch", num_processes, False, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore

Review Comment:
   Why do we need to add F403? I see that other test files only include F401.




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055997628


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...

Review Comment:
   Let's not, since we don't have pytest in PySpark.




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1061990731


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use PyTorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use PyTorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class PyTorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.

Review Comment:
   Yes, we should, right? It's just TensorFlow that's not supported...




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053899152


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    # can be required for TF distributor in the future as well
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                print(

Review Comment:
   You can use `warnings.warn` for now to be consistent with other places.
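
   For example, a minimal sketch of the suggested pattern:

   ```
   import warnings

   # Emit a RuntimeWarning rather than printing, to match other places in
   # the codebase.
   warnings.warn("This cluster has TLS encryption enabled; ...", RuntimeWarning)
   ```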




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053899418


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):

Review Comment:
   We don't have it for now; it would have to be manually parameterized, as you did.
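
   The manual parameterization pattern referred to here looks roughly like the following sketch (class and test names are illustrative):

   ```
   import unittest

   class ManuallyParameterizedTests(unittest.TestCase):
       def test_validate_correct_inputs(self) -> None:
           # Each tuple is one parameter set; subTest reports failures per case.
           inputs = [
               ("pytorch", 1, True, False),
               ("pytorch-lightning", 100, False, False),
           ]
           for framework, num_processes, local_mode, use_gpu in inputs:
               with self.subTest(framework=framework, num_processes=num_processes):
                   ...  # construct the distributor and assert the expected behavior
   ```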




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055806728


##########
python/pyspark/ml/torch/utils.py:
##########
@@ -0,0 +1,24 @@
+from pyspark.context import SparkContext

Review Comment:
   I will actually delete this file for now, since I don't expect many utils just yet; if we end up with more, we can move them out to `utils.py` later.




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1059095881


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)

Review Comment:
   There are a few situations to consider:
   1. spark.task.resource.gpu.amount > 1 + user code uses 1 GPU per process: This is fine, since for torchrun we could use `torchrun nproc_per_node=task_gpu_amount ...` as the command to be executed. This will launch multiple training processes per task. It will need some additional work toward the end, so I will create a backlog ticket for it.
   2. spark.task.resource.gpu.amount = 1 + user code uses > 1 GPU per process: This means using [model-parallel](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html) instead of [distributed-data-parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) training. It seems that distributed training currently only supports data parallelism, not model parallelism, so users would hopefully know not to use model parallelism. We could log this as well, though.
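
   To make the arithmetic in situation 1 concrete, a hypothetical sketch (the torchrun flag spelling and script name are assumptions, not the final command):

   ```
   import math

   num_processes = 8    # training processes requested by the user
   task_gpu_amount = 3  # value of spark.task.resource.gpu.amount

   # Number of Spark tasks, as computed in _get_num_tasks: ceil(8 / 3) == 3.
   num_tasks = math.ceil(num_processes / task_gpu_amount)

   # Hypothetical per-task command: each task would launch task_gpu_amount
   # training processes via torchrun.
   torchrun_cmd = ["torchrun", f"--nproc_per_node={task_gpu_amount}", "train.py"]
   ```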




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063017746


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+import stat
+import tempfile
+import unittest
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):

Review Comment:
   No, all tests will pass regardless of whether `pytorch` is installed, until we create a PR for https://issues.apache.org/jira/browse/SPARK-41777
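
   Presumably the eventual guard would follow the usual optional-dependency skip pattern, roughly (the class name here is illustrative):

   ```
   import unittest

   try:
       import torch  # noqa: F401

       have_torch = True
   except ImportError:
       have_torch = False

   @unittest.skipIf(not have_torch, "torch is required for these tests")
   class TorchDistributorGPUTests(unittest.TestCase):
       ...
   ```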




[GitHub] [spark] HyukjinKwon commented on pull request #39146: [SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1376569489

   Test failures are not related to this PR.
   
   Merged to master.



[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063038821


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(

Review Comment:
   @HyukjinKwon What is the convention in Spark/PySpark for handling an invalid conf?
   
   I think ignoring the invalid conf instead of throwing an error may be the proper approach.
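
   A sketch of what "ignore the invalid conf" could mean here (the function name and fallback behavior are assumptions about the suggestion, not a settled design):

   ```
   from pyspark.context import SparkContext

   def get_conf_boolean_lenient(sc: SparkContext, key: str, default_value: str) -> bool:
       # Same parsing as get_conf_boolean above, but fall back to the default
       # instead of raising on an unparseable value.
       val = sc.getConf().get(key, default_value)
       lowercase_val = val.lower()
       if lowercase_val == "true":
           return True
       if lowercase_val == "false":
           return False
       return default_value.lower() == "true"
   ```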




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063153907


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.getActiveSession()
+        if not self.spark:
+            raise RuntimeError("An active SparkSession is required for the distributor.")
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+
+        Raises
+        ------
+        RuntimeError
+            Raised when the SparkConf was misconfigured.
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if num_available_gpus == 0:
+                    raise RuntimeError("GPU resources were not configured properly on the driver.")
+                if self.num_processes > num_available_gpus:
+                    warnings.warn(
+                        f"'num_processes' cannot be set to a value greater than the number of "
+                        f"available GPUs on the driver, which is {num_available_gpus}. "
+                        f"'num_processes' was reset to be equal to the number of available GPUs.",
+                        RuntimeWarning,
+                    )
+                    self.num_processes = num_available_gpus
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        RuntimeError
+            Thrown when the user requires ssl encryption or when the user initializes
+            the Distributor parent class.
+        """
+        if not "ssl_conf":
+            raise RuntimeError(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise RuntimeError(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    ...     import torch.distributed
+    ...     torch.distributed.init_process_group(backend="nccl")
+    ...     # ...
+    ...     torch.distributed.destroy_process_group()
+    ...     return model # or anything else
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=False,
+    ...     use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training on GPU
+
+    >>> num_proc = 2
+    >>> def train():
+    ...     from pytorch_lightning import Trainer
+    ...     # ...
+    ...     # required to set devices = 1 and num_nodes == num_processes
+    ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    ...     trainer.fit()
+    ...     # ...
+    ...     return trainer
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch-lightning",
+    ...     num_processes=num_proc,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    # TODO(SPARK-41915): Remove need for setting frameworks in a future PR.
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether we are using PyTorch or PyTorch
+            Lightning. This must be either "pytorch" or "pytorch-lightning".
+        num_processes : int, optional
+            An integer that determines how many concurrent tasks are allowed. We expect
+            spark.task.resource.gpu.amount = 1 for GPU-enabled training. Defaults to 1;
+            we don't want to invoke multiple cores/GPUs without explicit mention.
+        local_mode : bool, optional
+            A boolean that determines whether we are using the driver
+            node for training. Defaults to True; we don't want to invoke executors without
+            explicit mention.
+        use_gpu : bool, optional
+            A boolean that indicates whether or not training runs
+            on the GPU. Note that GPU-enabled code is written differently
+            from CPU-only code.
+
+        Raises
+        ------
+        ValueError
+            If any of the parameters are incorrect.
+        RuntimeError
+            If an active SparkSession is unavailable.
+        """
+        # TODO(SPARK-41915): Remove framework requirement in a future PR.
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.framework = framework
+        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
+        self._validate_input_params()
+
+    def _validate_input_params(self) -> None:
+        """Validates input params
+
+        Raises
+        ------
+        ValueError
+            Thrown when user fails to provide correct input params
+        """
+        super()._validate_input_params()
+        if self.framework not in self.available_frameworks:
+            raise ValueError(
+                f"{self.framework} is not a valid framework."
+                f"Available frameworks: {self.available_frameworks}"
+            )
+
+    def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
+        """Runs distributed training.
+
+        Parameters
+        ----------
+        train_object : callable object or str
+            Either a PyTorch/PyTorch Lightning training function or the path to a python file
+            that launches distributed training.
+        args : *args
+            The arguments for train_object
+
+        Returns
+        -------
+        Optional[Any]
+            Returns the output of train_object(*args) if train_object is aCallable with an

Review Comment:
   ```suggestion
               Returns the output of train_object(*args) if train_object is a Callable with an
   ```
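
   To make the Callable-versus-file distinction in that docstring concrete, here is a minimal usage sketch of the proposed API (the training body, script path, and CLI arguments below are hypothetical, and `run()` is still unimplemented in this WIP):

   ```python
   from pyspark.ml.torch.distributor import TorchDistributor

   def train(learning_rate):
       # stand-in body; a real function would set up torch.distributed,
       # train a model, and return it (or metrics)
       return learning_rate * 2

   # Callable mode: run() is expected to return the output of train(*args)
   distributor = TorchDistributor(
       framework="pytorch", num_processes=2, local_mode=True, use_gpu=False)
   result = distributor.run(train, 1e-3)

   # File mode: run() launches the script with the given arguments; returns None
   distributor.run("/path/to/train.py", "--epochs", "5")
   ```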




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063163530


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "

Review Comment:
   Ok will update other errors as well.
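
   For illustration, the contract being settled here, as a sketch (assumes an active session named `spark`, mutating the conf the same way the tests later in this thread do):

   ```python
   sc = spark.sparkContext
   sc._conf.set("spark.ssl.enabled", "TRUE")

   get_conf_boolean(sc, "spark.ssl.enabled", "false")  # True: matching is case-insensitive
   get_conf_boolean(sc, "some.unset.key", "false")     # False: falls back to the default
   # any other stored value, e.g. "yes", raises the RuntimeError constructed above
   ```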




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063163530


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "

Review Comment:
   Ok, thanks for the clarification! :) 




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063774097


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import stat
+import tempfile
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            ("pytorch", 1, True, False),
+            ("pytorch", 100, True, False),
+            ("pytorch-lightning", 1, False, False),
+            ("pytorch-lightning", 100, False, False),
+        ]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            ("tensorflow", 1, True, False, ValueError, "framework"),
+            ("pytroch", 100, True, False, ValueError, "framework"),
+            ("pytorchlightning", 1, False, False, ValueError, "framework"),
+            ("pytorch-lightning", 0, False, False, ValueError, "positive"),
+        ]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_encryption_passes(self) -> None:
+        inputs = [
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+        ]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = TorchDistributor("pytorch", 1, True, False)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self) -> None:
+        # this is the only combination that should fail
+        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = TorchDistributor("pytorch", 1, True, False)
+                    distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor("pytorch", num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor("pytorch", num_processes, False, True)
+
+
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.driver.resource.gpu.amount", "3")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertWarns(RuntimeWarning):
+                    distributor = TorchDistributor("pytorch", num_processes, True, True)
+                    distributor.num_processes = 3
+
+
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_distributed(self) -> None:
+        inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(
+                    "spark.task.resource.gpu.amount", str(spark_conf_value)
+                )
+                distributor = TorchDistributor("pytorch", num_processes, False, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore

Review Comment:
   `mypy` raises errors otherwise



##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import stat
+import tempfile
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            ("pytorch", 1, True, False),
+            ("pytorch", 100, True, False),
+            ("pytorch-lightning", 1, False, False),
+            ("pytorch-lightning", 100, False, False),
+        ]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            ("tensorflow", 1, True, False, ValueError, "framework"),
+            ("pytroch", 100, True, False, ValueError, "framework"),
+            ("pytorchlightning", 1, False, False, ValueError, "framework"),
+            ("pytorch-lightning", 0, False, False, ValueError, "positive"),
+        ]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_encryption_passes(self) -> None:
+        inputs = [
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+        ]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = TorchDistributor("pytorch", 1, True, False)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self) -> None:
+        # this is the only combination that should fail
+        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = TorchDistributor("pytorch", 1, True, False)
+                    distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor("pytorch", num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor("pytorch", num_processes, False, True)
+
+
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.driver.resource.gpu.amount", "3")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertWarns(RuntimeWarning):
+                    distributor = TorchDistributor("pytorch", num_processes, True, True)
+                    distributor.num_processes = 3
+
+
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_distributed(self) -> None:
+        inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(
+                    "spark.task.resource.gpu.amount", str(spark_conf_value)
+                )
+                distributor = TorchDistributor("pytorch", num_processes, False, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
+
+    try:
+        import xmlrunner  # type: ignore

Review Comment:
   `mypy` raises errors otherwise
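
   For reference, the quoted file truncates right at this import; a sketch of the conventional PySpark test tail it presumably follows (report path and verbosity are assumptions):

   ```python
   if __name__ == "__main__":
       from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403

       try:
           import xmlrunner  # produces JUnit-style XML reports when installed

           testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
       except ImportError:
           testRunner = None
       unittest.main(testRunner=testRunner, verbosity=2)
   ```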




[GitHub] [spark] WeichenXu123 commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1373104208

   > Note: there are a lot of relevant comments on #39188. Should I make those changes here or should I make those changes separately in that PR (to make sure the comment history makes sense)? @WeichenXu123
   
   Put the updated code in whichever PR it is more closely related to,
   and you can link to comments in the other PRs if needed.



[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054100034


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()

Review Comment:
   are there cases where users may want to specify the SparkSession?
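
   A sketch of the pattern this question points toward: accept an explicit session, fall back to the active one, and fail loudly otherwise (the parameter name is an assumption):

   ```python
   from typing import Optional
   from pyspark.sql import SparkSession

   class Distributor:
       def __init__(self, spark: Optional[SparkSession] = None):
           # prefer a caller-supplied session; otherwise reuse the active one
           # rather than silently creating a new session via getOrCreate()
           self.spark = spark or SparkSession.getActiveSession()
           if self.spark is None:
               raise RuntimeError("An active SparkSession is required for the distributor.")
           self.sc = self.spark.sparkContext
   ```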




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014563


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(

Review Comment:
   will make it a `RuntimeError`




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063017531


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training

Review Comment:
   ```suggestion
       Run PyTorch Lightning Training
   
   ```
   
   Otherwise, documentation rendering is broken.




[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063666448


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   In this case, it is fine with me to write a new function. Just consider making a general utils function for strtobool.
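
   A sketch of such a general helper, modeled on the (now deprecated) distutils.util.strtobool; the name and accepted spellings are assumptions:

   ```python
   def strtobool(value: str) -> bool:
       """Convert a string representation of truth to a bool."""
       val = value.strip().lower()
       if val in ("y", "yes", "t", "true", "on", "1"):
           return True
       if val in ("n", "no", "f", "false", "off", "0"):
           return False
       raise ValueError(f"invalid truth value: {value!r}")

   # get_conf_boolean could then delegate:
   #     return strtobool(sc.getConf().get(key, default_value))
   ```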




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063117818


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(

Review Comment:
   Sure, that seems reasonable.





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063124224


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))

Review Comment:
   Sure, will add that check.
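
   A sketch of the kind of check being agreed to here (the exact condition is an assumption, since the reviewer's suggestion is not quoted):

   ```python
   def _validate_driver_gpus(num_available_gpus: int, key: str) -> None:
       # Fail fast if the driver conf reports no usable GPUs for local-mode training.
       if num_available_gpus < 1:
           raise RuntimeError(f"'{key}' was unset or zero, so local GPU training cannot run.")
   ```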





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063200072


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    available_frameworks = ["pytorch", "pytorch-lightning"]

Review Comment:
   Is it complex to remove this? If not, I would prefer to remove it from this PR.
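
   For context, dropping it would leave the constructor with a call shape like the following (a sketch, assuming the rest of the signature stays as proposed):

   ```python
   distributor = TorchDistributor(num_processes=2, local_mode=True, use_gpu=True)
   ```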





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1064528172


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   @HyukjinKwon Could you add a PySpark helper method for handling boolean configs? No need to block this PR.
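
   For reference, a sketch of such a helper built on the standard library (illustrative only; note that `distutils.util.strtobool` is deprecated along with `distutils` since Python 3.10, and it accepts more spellings, e.g. "yes"/"on", than Scala's "true"/"false"):

   ```python
   from distutils.util import strtobool

   from pyspark.context import SparkContext

   def get_bool_conf(sc: SparkContext, key: str, default: str = "false") -> bool:
       # strtobool returns 0/1 and raises ValueError on unrecognized input.
       return bool(strtobool(sc.getConf().get(key, default)))
   ```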





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054097560


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):
+        inputs = [("pytorch", 1, True, True),
+                  ("pytorch", 100, True, False),
+                  ("pytorch-lightning", 1, False, True),
+                  ("pytorch-lightning", 100, False, False)]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self):
+        inputs = [("tensorflow", 1, True, True, ValueError, "framework"),
+                  ("pytroch", 100, True, False, ValueError, "framework"),
+                  ("pytorchlightning", 1, False, True, ValueError, "framework"),
+                  ("pytorch-lightning", 0, False, False, ValueError, "positive")]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_get_correct_num_tasks_when_spark_conf_is_set(self):
+        inputs = [(1, 8, 8),
+                  (2, 8, 4),
+                  (3, 8, 3)]
+        # this is when the sparkconf isn't set
+        for _, num_processes, _ in inputs:
+            with self.subTest():
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), num_processes)
+        
+        # this is when the sparkconf is set
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", spark_conf_value)
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+    def test_encryption_passes(self):
+        input_combination = [("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+                             ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+                             ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = PyTorchDistributor("pytorch", 1, True, True)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self):
+        # this is the only combination that should fail
+        input_combination = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = PyTorchDistributor("pytorch", 1, True, True)
+                    distributor._check_encryption()
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

Review Comment:
   +1





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054095381


##########
python/pyspark/ml/torch/utils.py:
##########
@@ -0,0 +1,24 @@
+from pyspark.context import SparkContext
+
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str):
+        """
+        Get the conf "key" from the given spark context,
+        or return the default value if the conf is not set.
+        This expects the conf value to be a boolean or string;
+        if the value is a string, this checks for all capitalization
+        patterns of "true" and "false" to match Scala.
+        Args:
+            key: string for conf name
+            default_value: default value for the conf value for the given key
+        """
+        val = sc.getConf().get(key, default_value)
+        lowercase_val = val.lower()
+        if lowercase_val == "true":
+            return True
+        if lowercase_val == "false":
+            return False
+        raise Exception(
+            "_getConfBoolean expected a boolean conf "

Review Comment:
   ```suggestion
               "get_conf_boolean expected a boolean conf "
   ```





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1056647526


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):

Review Comment:
   I want to start a SparkSession before each test and stop it after each test, not once when the class starts and ends. Correct me if I am wrong, but `ReusedSQLTestCase` seems to do the setup and teardown only when the class starts and ends, not around each individual test.
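
   A minimal sketch of the per-test lifecycle described above (illustrative only):

   ```python
   import unittest

   from pyspark.sql import SparkSession

   class PerTestSparkCase(unittest.TestCase):
       def setUp(self) -> None:
           # A fresh session per test keeps conf changes from leaking across tests.
           self.spark = SparkSession.builder.master("local[4]").getOrCreate()

       def tearDown(self) -> None:
           self.spark.stop()
   ```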





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1059095881


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)

Review Comment:
   There are a few situations:
   1. spark.task.resource.gpu.amount > 1 while the user code uses only 1 GPU: a single GPU per task is used for training, so the extra GPUs are wasted. We could log this so the user is aware of that fact.
   2. spark.task.resource.gpu.amount = 1 while the user code uses > 1 GPU: this implies [model parallelism](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html) rather than [distributed data parallelism](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html). Distributed training currently supports only data parallelism, not model parallelism, so users would hopefully know not to do this. We could log this as well, though.
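
   A sketch of how a warning for situation 1 could sit next to the task-count computation (illustrative only; no such warning is in this PR):

   ```python
   import math
   import warnings

   def get_num_tasks(num_processes: int, task_gpu_amount: int) -> int:
       # Each Spark task owns `task_gpu_amount` GPUs, but each training process uses one.
       if task_gpu_amount > 1:
           warnings.warn(
               f"spark.task.resource.gpu.amount is {task_gpu_amount}, but each "
               "training process uses a single GPU; some GPUs may sit idle."
           )
       return math.ceil(num_processes / task_gpu_amount)
   ```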





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014892


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)

Review Comment:
   ```suggestion
       >>> distributor = TorchDistributor(
       ...     framework="pytorch-lightning",
       ...     num_processes=num_proc,
       ...     local_mode=True,
       ...     use_gpu=True)
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063016151


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+import stat
+import tempfile
+import unittest
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+        self.sc.stop()

Review Comment:
   ```suggestion
   ```
   
   Stopping the SparkSession also stops the underlying SparkContext.
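
   With that, the tearDown can be reduced to (sketch):

   ```python
   def tearDown(self) -> None:
       self.spark.stop()  # also stops the underlying SparkContext
   ```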





[GitHub] [spark] rithwik-db commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1373000167

   Note: there are a lot of relevant comments on https://github.com/apache/spark/pull/39188. Should I make those changes here, or separately once that PR gets merged (to keep the comment history coherent)? @WeichenXu123 




[GitHub] [spark] HyukjinKwon commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1373002232

   Assuming that adding PyTorch to the CI will be done in a separate PR, I don't have any comments. I would defer to @mengxr @WeichenXu123 @zhengruifeng to sign off.




[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063042745


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training

Review Comment:
   To make it consistent:
   ```
   Run PyTorch Lightning Training on GPU
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063155314


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+import stat
+import tempfile
+import unittest

Review Comment:
   ```suggestion
   import os
   import stat
   import tempfile
   import unittest
   
   from pyspark import SparkConf, SparkContext
   from pyspark.ml.torch.distributor import TorchDistributor
   from pyspark.sql import SparkSession
   from pyspark.testing.utils import SPARK_HOME
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063154959


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.getActiveSession()
+        if not self.spark:
+            raise RuntimeError("An active SparkSession is required for the distributor.")
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+
+        Raises
+        ------
+        RuntimeError
+            Raised when the SparkConf was misconfigured.
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if num_available_gpus == 0:
+                    raise RuntimeError("GPU resources were not configured properly on the driver.")
+                if self.num_processes > num_available_gpus:
+                    warnings.warn(
+                        f"'num_processes' cannot be set to a value greater than the number of "
+                        f"available GPUs on the driver, which is {num_available_gpus}. "
+                        f"'num_processes' was reset to be equal to the number of available GPUs.",
+                        RuntimeWarning,
+                    )
+                    self.num_processes = num_available_gpus
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        RuntimeError
+            Thrown when the user requires ssl encryption or when the user initializes
+            the Distributor parent class.
+        """
+        if not "ssl_conf":
+            raise RuntimeError(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise RuntimeError(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    ...     import torch.distributed
+    ...     torch.distributed.init_process_group(backend="nccl")
+    ...     # ...
+    ...     torch.distributed.destroy_process_group()
+    ...     return model # or anything else
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=False,
+    ...     use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training on GPU
+
+    >>> num_proc = 2
+    >>> def train():
+    ...     from pytorch_lightning import Trainer
+    ...     # ...
+    ...     # required to set devices = 1 and num_nodes == num_processes
+    ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    ...     trainer.fit()
+    ...     # ...
+    ...     return trainer
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch-lightning",
+    ...     num_processes=num_proc,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    # TODO(SPARK-41915): Remove need for setting frameworks in a future PR.
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether we are using PyTorch or PyTorch
+            Lightning. This must be either the string "pytorch" or "pytorch-lightning".
+        num_processes : int, optional
+            An integer that determines how many different concurrent
+            tasks are allowed. We expect spark.task.resource.gpu.amount = 1 for GPU-enabled
+            training. Defaults to 1; we don't want to invoke multiple cores/GPUs without
+            explicit mention.
+        local_mode : bool, optional
+            A boolean that determines whether we are using the driver
+            node for training. Defaults to True; we don't want to invoke executors
+            without explicit mention.
+        use_gpu : bool, optional
+            A boolean that indicates whether or not we are doing training
+            on the GPU. Note that GPU-enabled code looks different from CPU-specific code.
+
+        Raises
+        ------
+        ValueError
+            If any of the parameters are incorrect.
+        RuntimeError
+            If an active SparkSession is unavailable.
+        """
+        # TODO(SPARK-41915): Remove framework requirement in a future PR.
+        super().__init__(num_processes, local_mode, use_gpu)
+        self.framework = framework
+        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
+        self._validate_input_params()
+
+    def _validate_input_params(self) -> None:
+        """Validates input params
+
+        Raises
+        ------
+        ValueError
+            Thrown when user fails to provide correct input params
+        """
+        super()._validate_input_params()
+        if self.framework not in self.available_frameworks:
+            raise ValueError(
+                f"{self.framework} is not a valid framework."
+                f"Available frameworks: {self.available_frameworks}"
+            )
+
+    def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
+        """Runs distributed training.
+
+        Parameters
+        ----------
+        train_object : callable object or str
+            Either a PyTorch/PyTorch Lightning training function or the path to a python file
+            that launches distributed training.
+        args : *args
+            The arguments for train_object
+
+        Returns
+        -------
+        Optional[Any]

Review Comment:
   ```suggestion
   ```



##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os

Review Comment:
   ```suggestion
   import os
   
   ```





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063202455


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import stat
+import tempfile
+import unittest
+
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            ("pytorch", 1, True, False),
+            ("pytorch", 100, True, False),
+            ("pytorch-lightning", 1, False, False),
+            ("pytorch-lightning", 100, False, False),
+        ]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            ("tensorflow", 1, True, False, ValueError, "framework"),
+            ("pytroch", 100, True, False, ValueError, "framework"),
+            ("pytorchlightning", 1, False, False, ValueError, "framework"),
+            ("pytorch-lightning", 0, False, False, ValueError, "positive"),
+        ]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_encryption_passes(self) -> None:
+        inputs = [
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+        ]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = TorchDistributor("pytorch", 1, True, False)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self) -> None:
+        # this is the only combination that should fail
+        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = TorchDistributor("pytorch", 1, True, False)
+                    distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor("pytorch", num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor("pytorch", num_processes, False, True)
+
+
+class TorchDistributorLocalUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.driver.resource.gpu.amount", "3")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_locally(self) -> None:
+        succeeds = [1, 2]
+        fails = [4, 8]
+        for num_processes in succeeds:
+            with self.subTest():
+                expected_output = num_processes
+                distributor = TorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        for num_processes in fails:
+            with self.subTest():
+                with self.assertWarns(RuntimeWarning):
+                    distributor = TorchDistributor("pytorch", num_processes, True, True)
+                    distributor.num_processes = 3
+
+
+class TorchDistributorDistributedUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(
+            b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\",\\"1\\",\\"2\\"]}'
+        )
+        self.tempFile.close()
+        # create temporary directory for Worker resources coordination
+        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(self.tempdir.name)
+        os.chmod(
+            self.tempFile.name,
+            stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH,
+        )
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
+
+        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
+        conf = conf.set("spark.worker.resource.gpu.amount", "3")
+        conf = conf.set("spark.task.cpus", "2")
+        conf = conf.set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+
+        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        os.unlink(self.tempFile.name)
+        self.spark.stop()
+
+    def test_get_num_tasks_distributed(self) -> None:
+        inputs = [(1, 8, 8), (2, 8, 4), (3, 8, 3)]
+
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(
+                    "spark.task.resource.gpu.amount", str(spark_conf_value)
+                )
+                distributor = TorchDistributor("pytorch", num_processes, False, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
+
+    try:
+        import xmlrunner  # type: ignore

Review Comment:
   And also not seeing `# type: ignore` in other test files.
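   For reference, a sketch of the `__main__` block without the extra ignores (mirroring the earlier revision in this thread; `xmlrunner` remains an optional test dependency):
   ```
   if __name__ == "__main__":
       from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403

       try:
           import xmlrunner

           testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
       except ImportError:
           testRunner = None
       unittest.main(testRunner=testRunner, verbosity=2)
   ```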





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063068475


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))

Review Comment:
   If `num_available_gpus` == 0, we should raise an error saying the GPU resources are not configured properly?
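   A minimal sketch of such a check (this is the shape a later revision in this thread adopts):
   ```
   num_available_gpus = int(self.sc.getConf().get(key, "0"))
   if num_available_gpus == 0:
       raise RuntimeError("GPU resources were not configured properly on the driver.")
   ```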





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063039894


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   Why not just use something like
   ```
   bool(sc.getConf().get(key, default_value))
   ```
   and let Python handle the issue?
   Then we don't need to create a function.
   We are doing similar things for `task_gpu_amount`.
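   
   (For context, a quick illustration of the pitfall with plain `bool()` -- any non-empty string is truthy in Python, so it cannot distinguish "true" from "false":
   ```
   >>> bool("false")
   True
   >>> bool("")
   False
   ```
   hence the explicit comparison against "true"/"false" in the helper.)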
   





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055997705


##########
dev/sparktestsupport/modules.py:
##########
@@ -617,6 +617,7 @@ def __hash__(self):
         "pyspark.ml.tests.test_tuning",
         "pyspark.ml.tests.test_util",
         "pyspark.ml.tests.test_wrapper",
+        "pyspark.ml.torch.tests.test_distributor"

Review Comment:
   ```suggestion
           "pyspark.ml.torch.tests.test_distributor",
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053897964


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+import unittest
+
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):
+        inputs = [("pytorch", 1, True, True),
+                  ("pytorch", 100, True, False),
+                  ("pytorch-lightning", 1, False, True),
+                  ("pytorch-lightning", 100, False, False)]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self):
+        inputs = [("tensorflow", 1, True, True, ValueError, "framework"),
+                  ("pytroch", 100, True, False, ValueError, "framework"),
+                  ("pytorchlightning", 1, False, True, ValueError, "framework"),
+                  ("pytorch-lightning", 0, False, False, ValueError, "positive")]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_get_correct_num_tasks_when_spark_conf_is_set(self):
+        inputs = [(1, 8, 8),
+                  (2, 8, 4),
+                  (3, 8, 3)]
+        # this is when the sparkconf isn't set
+        for _, num_processes, _ in inputs:
+            with self.subTest():
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), num_processes)
+        
+        # this is when the sparkconf is set
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", spark_conf_value)
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+    def test_encryption_passes(self):
+        input_combination = [("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+                             ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+                             ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = PyTorchDistributor("pytorch", 1, True, True)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self):
+        # this is the only combination that should fail
+        input_combination = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = PyTorchDistributor("pytorch", 1, True, True)
+                    distributor._check_encryption()
+
+if __name__ == "__main__":
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

Review Comment:
   Should probably add this file into https://github.com/apache/spark/blob/master/dev/sparktestsupport/modules.py#L605



##########
python/pyspark/ml/torch/utils.py:
##########
@@ -0,0 +1,24 @@
+from pyspark.context import SparkContext

Review Comment:
   license header missing





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053893899


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    # can be required for TF distributor in the future as well
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                print(

Review Comment:
   There will be a logger in the next PR. For now, just leaving it as a print() function.
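   For reference, the later revision in this thread swaps the `print()` for a warning along these lines (abbreviated sketch; `name` is the class name from the surrounding code):
   ```
   import warnings

   warnings.warn(
       f"This cluster has TLS encryption enabled; however, {name} does not "
       "support data encryption in transit. ...",
       RuntimeWarning,
   )
   ```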





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1058883692


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)

Review Comment:
   Q:
   What if the Spark config "spark.task.resource.gpu.amount" is > 1 but the user-provided training code can use only one GPU per task?
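   For context, the current formula computes `math.ceil(num_processes / task_gpu_amount)` Spark tasks. A quick worked example (illustrative only, matching the test inputs elsewhere in this PR):
   ```
   >>> import math
   >>> math.ceil(8 / 3)  # num_processes=8, spark.task.resource.gpu.amount=3
   3
   ```
   Each of those 3 tasks is then handed 3 GPUs, so if the training code drives only one GPU per task, the remaining GPUs would presumably sit idle.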





[GitHub] [spark] HyukjinKwon commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1363559268

   cc @WeichenXu123 and @mengxr 




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055794843


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()

Review Comment:
   There actually might be (specifically when users want to set `spark.task.resource.gpu.amount > 1`); let me fix this by adding a new parameter as input.
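   For reference, the interim revision in this thread adds the session as an optional constructor argument, roughly:
   ```
   def __init__(
       self,
       num_processes: int = 1,
       local_mode: bool = True,
       use_gpu: bool = True,
       spark: Optional[SparkSession] = None,
   ):
       ...
       self.spark = spark if spark else SparkSession.builder.getOrCreate()
   ```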





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063148764


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "

Review Comment:
   I think we should show `key` instead of `get_conf_boolean`. Actually, let's also throw `ValueError`.
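   A possible shape for the revised error, per this suggestion (a sketch, not the merged code):
   ```
   raise ValueError(
       f"The Spark conf '{key}' expected a boolean value but found "
       f"value of type {type(val)} with value: {val}"
   )
   ```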





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063015529


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether we are using PyTorch or PyTorch
+            Lightning. This should be either "pytorch" or "pytorch-lightning".
+        num_processes : int, optional
+            An integer that determines how many different concurrent
+            tasks are allowed. We expect spark.task.resource.gpu.amount = 1 for GPU-enabled
+            training. Default should be 1; we don't want to invoke multiple cores/gpus
+            without explicit mention.
+        local_mode : bool, optional
+            A boolean that determines whether we are using the driver
+            node for training. Default should be True; we don't want to invoke executors
+            without explicit mention.
+        use_gpu : bool, optional
+            A boolean that indicates whether or not we are doing training
+            on the GPU. Note that there are differences between what GPU-enabled
+            code looks like and what CPU-specific code looks like.
+        spark : Optional[SparkSession], optional
+            An optional parameter that allows users to pass in a custom SparkSession argument
+            with a custom conf, by default None
+
+        Raises
+        ------
+        ValueError
+            If any of the parameters are incorrect.
+        """
+        # TODO(SPARK-41915): Remove framework requirement in a future PR.
+        super().__init__(num_processes, local_mode, use_gpu, spark)
+        self.framework = framework
+        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
+        self._validate_input_params()
+
+    def _validate_input_params(self) -> None:
+        """Validates input params
+
+        Raises
+        ------
+        ValueError
+            Thrown when user fails to provide correct input params
+        """
+        super()._validate_input_params()
+        if self.framework not in self.available_frameworks:
+            raise ValueError(
+                f"{self.framework} is not a valid framework."
+                f"Available frameworks: {self.available_frameworks}"
+            )
+
+    def run(self, train_fn: Union[Callable, str], *args: Union[str, int]) -> Union[None, Any]:
+        """Runs distributed training.
+
+        Parameters
+        ----------
+        train_fn : Union[Callable, str]

Review Comment:
   per numpydoc (https://numpydoc.readthedocs.io/en/latest/format.html#parameters)
   
   ```suggestion
           train_fn : callable object or str
   ```
   
   Would have to fix all places in this PR.
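   For reference, a minimal sketch of how a full numpydoc entry could then read (the description text is illustrative, since the hunk is truncated before the parameter description):

   ```python
   def run(self, train_fn, *args):
       """Runs distributed training.

       Parameters
       ----------
       train_fn : callable object or str
           Either a function to run on each worker or a path to a Python
           file containing the training code.
       """
   ```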



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether we are using PyTorch or PyTorch
+            Lightning. This should be either "pytorch" or "pytorch-lightning".
+        num_processes : int, optional
+            An integer that determines how many different concurrent
+            tasks are allowed. We expect spark.task.resource.gpu.amount = 1 for GPU-enabled
+            training. Default should be 1; we don't want to invoke multiple cores/gpus
+            without explicit mention.
+        local_mode : bool, optional
+            A boolean that determines whether we are using the driver
+            node for training. Default should be True; we don't want to invoke executors
+            without explicit mention.
+        use_gpu : bool, optional
+            A boolean that indicates whether or not we are doing training
+            on the GPU. Note that there are differences between what GPU-enabled
+            code looks like and what CPU-specific code looks like.
+        spark : Optional[SparkSession], optional
+            An optional parameter that allows users to pass in a custom SparkSession argument
+            with a custom conf, by default None
+
+        Raises
+        ------
+        ValueError
+            If any of the parameters are incorrect.
+        """
+        # TODO(SPARK-41915): Remove framework requirement in a future PR.
+        super().__init__(num_processes, local_mode, use_gpu, spark)
+        self.framework = framework
+        self.ssl_conf = "pytorch.spark.distributor.ignoreSsl"  # type: ignore
+        self._validate_input_params()
+
+    def _validate_input_params(self) -> None:
+        """Validates input params
+
+        Raises
+        ------
+        ValueError
+            Thrown when user fails to provide correct input params
+        """
+        super()._validate_input_params()
+        if self.framework not in self.available_frameworks:
+            raise ValueError(
+                f"{self.framework} is not a valid framework."
+                f"Available frameworks: {self.available_frameworks}"
+            )
+
+    def run(self, train_fn: Union[Callable, str], *args: Union[str, int]) -> Union[None, Any]:

Review Comment:
   ```suggestion
       def run(self, train_fn: Union[Callable, str], *args: Union[str, int]) -> Optional[Any]:
   ```
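   As a side note, the two spellings are interchangeable; a minimal sketch (function names are illustrative):

   ```python
   from typing import Any, Optional, Union

   # Optional[X] is shorthand for Union[X, None], so these two annotations
   # describe the same type; Optional[Any] is simply the idiomatic spelling.
   def run_a() -> Union[None, Any]:
       return None

   def run_b() -> Optional[Any]:
       return None
   ```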





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014721


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer

Review Comment:
   ```suggestion
       >>> def train():
       ...     from pytorch_lightning import Trainer
       ...     # ...
       ...     # required to set devices = 1 and num_nodes == num_processes
       ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
       ...     trainer.fit()
       ...     # ...
       ...     return trainer
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063017417


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+import stat
+import tempfile
+import unittest
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):

Review Comment:
   Does this test fail if `pytorch` is not installed? If so, we should skip when it's not installed. See `pyspark.testing.sqlutils` and `git grep -r "@unittest.skipIf"`.
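   As a sketch, a simple import probe would look roughly like the following (the `have_torch` flag and skip message are illustrative, not the actual helpers in `pyspark.testing.sqlutils`):

   ```python
   import unittest

   try:
       import torch  # noqa: F401

       have_torch = True
   except ImportError:
       have_torch = False


   @unittest.skipIf(not have_torch, "torch is required for these tests")
   class TorchDistributorBaselineUnitTests(unittest.TestCase):
       ...
   ```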





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063156580


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "

Review Comment:
   We actually use JVM errors. That is usually `IllegalArgumentException`, and I believe `ValueError` is the closest match to that.





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014363


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(

Review Comment:
   Will make this a `RuntimeError`
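   For concreteness, a minimal sketch of that follow-up (assuming only the exception type changes, with the guard testing the instance attribute):

   ```python
   if not self.ssl_conf:
       raise RuntimeError(
           "Distributor doesn't have this functionality. Use TorchDistributor instead."
       )
   ```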





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063008896


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# from pyspark.testing.sqlutils import ReusedSQLTestCase
+import os
+from pyspark import SparkConf, SparkContext
+from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.sql import SparkSession
+from pyspark.testing.utils import SPARK_HOME
+import stat
+import tempfile
+import unittest
+
+
+class TorchDistributorBaselineUnitTests(unittest.TestCase):
+    def setUp(self) -> None:
+        conf = SparkConf()
+        self.sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+        self.sc.stop()
+
+    def test_validate_correct_inputs(self) -> None:
+        inputs = [
+            ("pytorch", 1, True, False),
+            ("pytorch", 100, True, False),
+            ("pytorch-lightning", 1, False, False),
+            ("pytorch-lightning", 100, False, False),
+        ]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self) -> None:
+        inputs = [
+            ("tensorflow", 1, True, False, ValueError, "framework"),
+            ("pytroch", 100, True, False, ValueError, "framework"),
+            ("pytorchlightning", 1, False, False, ValueError, "framework"),
+            ("pytorch-lightning", 0, False, False, ValueError, "positive"),
+        ]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    TorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_encryption_passes(self) -> None:
+        inputs = [
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+            ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+            ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true"),
+        ]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = TorchDistributor("pytorch", 1, True, False)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self) -> None:
+        # this is the only combination that should fail
+        inputs = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = TorchDistributor("pytorch", 1, True, False)
+                    distributor._check_encryption()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(RuntimeError, "driver"):
+                    TorchDistributor("pytorch", num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor("pytorch", num_processes, False, True)
+
+
+class TorchDistributorLocalUnitTests(unittest.TestCase):

Review Comment:
   More tests will be added to this class and the following class in later PRs.





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063043172


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    available_frameworks = ["pytorch", "pytorch-lightning"]

Review Comment:
   Based on the discussion with Xiangrui, do we still need this framework?





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063040647


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(

Review Comment:
   How about throwing a warning message and setting `self.num_processes` to `num_available_gpus`?
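   For illustration, that fallback might look roughly like this inside `_get_num_tasks` (a sketch only; the warning wording is an assumption, not code from the PR):

   ```python
   # Hypothetical sketch: warn and clamp instead of raising when
   # num_processes exceeds the GPUs available on the driver.
   if self.num_processes > num_available_gpus:
       warnings.warn(
           f"num_processes ({self.num_processes}) exceeds the "
           f"{num_available_gpus} GPUs available on the driver; "
           f"setting num_processes to {num_available_gpus}.",
           RuntimeWarning,
       )
       self.num_processes = num_available_gpus
   ```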





[GitHub] [spark] zhengruifeng commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1364629485

   @rithwik-db could you please fix the Python lint?
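   (For reference, the Python lint can usually be run locally with `dev/lint-python` from the repo root, assuming the standard Spark dev setup.)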




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053898774


##########
python/pyspark/ml/__init__.py:
##########
@@ -66,4 +69,5 @@
     "util",
     "linalg",
     "param",
+    "PyTorchDistributor"

Review Comment:
   This should also be listed in https://github.com/apache/spark/blob/master/python/docs/source/reference/pyspark.ml.rst so it shows up in the API reference.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055997534


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+# Moved the util functions to this file for now
+# TODO: will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str):
+        """
+        Get the conf "key" from the given spark context,
+        or return the default value if the conf is not set.
+        This expects the conf value to be a boolean or string;
+        if the value is a string, this checks for all capitalization
+        patterns of "true" and "false" to match Scala.
+        Args:
+            key: string for conf name
+            default_value: default value for the conf value for the given key
+        """
+        val = sc.getConf().get(key, default_value)
+        lowercase_val = val.lower()
+        if lowercase_val == "true":
+            return True
+        if lowercase_val == "false":
+            return False
+        raise Exception(
+            "get_conf_boolean expected a boolean conf "
+            "value but found value of type {} "
+            "with value: {}".format(type(val), val)
+        )
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, spark: Optional[SparkSession] = None):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+class PyTorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    TODO: Add examples in the future

Review Comment:
   Let's probably file a JIRA and write this todo as `TODO(SPARK-XXXXX): blah blah` so the todos don't get forgotten :-).





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054110428


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):
+        inputs = [("pytorch", 1, True, True),
+                  ("pytorch", 100, True, False),
+                  ("pytorch-lightning", 1, False, True),
+                  ("pytorch-lightning", 100, False, False)]
+        for framework, num_processes, local_mode, use_gpu in inputs:
+            with self.subTest():
+                PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_validate_incorrect_inputs(self):
+        inputs = [("tensorflow", 1, True, True, ValueError, "framework"),
+                  ("pytroch", 100, True, False, ValueError, "framework"),
+                  ("pytorchlightning", 1, False, True, ValueError, "framework"),
+                  ("pytorch-lightning", 0, False, False, ValueError, "positive")]
+        for framework, num_processes, local_mode, use_gpu, error, message in inputs:
+            with self.subTest():
+                with self.assertRaisesRegex(error, message):
+                    PyTorchDistributor(framework, num_processes, local_mode, use_gpu)
+
+    def test_get_correct_num_tasks_when_spark_conf_is_set(self):
+        inputs = [(1, 8, 8),
+                  (2, 8, 4),
+                  (3, 8, 3)]
+        # this is when the sparkconf isn't set
+        for _, num_processes, _ in inputs:
+            with self.subTest():
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), num_processes)
+        
+        # this is when the sparkconf is set
+        for spark_conf_value, num_processes, expected_output in inputs:
+            with self.subTest():
+                self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", spark_conf_value)
+                distributor = PyTorchDistributor("pytorch", num_processes, True, True)
+                self.assertEqual(distributor._get_num_tasks(), expected_output)
+
+    def test_encryption_passes(self):
+        input_combination = [("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "true"),
+                             ("spark.ssl.enabled", "false", "pytorch.spark.distributor.ignoreSsl", "false"),
+                             ("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "true")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                distributor = PyTorchDistributor("pytorch", 1, True, True)
+                distributor._check_encryption()
+
+    def test_encryption_fails(self):
+        # this is the only combination that should fail
+        input_combination = [("spark.ssl.enabled", "true", "pytorch.spark.distributor.ignoreSsl", "false")]
+        for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value in input_combination:
+            with self.subTest():
+                with self.assertRaisesRegex(Exception, "encryption"):
+                    self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.sparkContext._conf.set(pytorch_conf_key, pytorch_conf_value)
+                    distributor = PyTorchDistributor("pytorch", 1, True, True)
+                    distributor._check_encryption()
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_util import *  # noqa: F401

Review Comment:
   ```suggestion
       from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401
   ```





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1059095881


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)

Review Comment:
   There are a few situations:
   1. spark.task.resource.gpu.amount > 1 + user code can use 1 GPU: This is fine, since for torchrun we could use `torchrun nproc_per_node=task_gpu_amount ...` as the command to be executed (see the sketch below). This will launch multiple training processes per executor.
   2. spark.task.resource.gpu.amount = 1 + user code can use > 1 GPU: This means using [model-parallel](https://pytorch.org/tutorials/intermediate/model_parallel_tutorial.html) instead of [distributed-data-parallel](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) training. Currently, distributed training only supports data parallel and not model parallel, so users would hopefully know not to use model parallel. We could log this as well, though.
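   For illustration, situation 1 could translate into a launch command along these lines (a sketch; `task_gpu_amount` and `train_file` are stand-in names, not identifiers from the PR):

   ```python
   # Hypothetical sketch: one torchrun invocation per executor, starting
   # one training process per GPU assigned to the Spark task.
   import subprocess

   def launch_torchrun(task_gpu_amount: int, train_file: str) -> None:
       subprocess.run(
           ["torchrun", f"--nproc_per_node={task_gpu_amount}", train_file],
           check=True,
       )
   ```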





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054114029


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    # can be required for TF distributor in the future as well
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                print(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+class PyTorchDistributor(Distributor):
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+    PICKLED_FUNC_FILE = "func.pickle"
+    TRAIN_FILE = "train.py"
+    
+    def __init__(self, framework: str, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):

Review Comment:
   please add documentation
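   For illustration, a numpydoc-style sketch of what that documentation could look like (the parameter descriptions are assumptions, not the PR's final wording):

   ```python
   def __init__(self, framework: str, num_processes: int = 1,
                local_mode: bool = True, use_gpu: bool = True):
       """
       Parameters
       ----------
       framework : str
           The training framework to use: "pytorch" or "pytorch-lightning".
       num_processes : int, optional
           Number of training processes to launch (default 1).
       local_mode : bool, optional
           Whether training runs only on the driver node (default True).
       use_gpu : bool, optional
           Whether to use GPUs for training (default True).
       """
       ...
   ```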





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055997824


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()

Review Comment:
   ```suggestion
   ```





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1056653132


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+# Moved the util functions to this file for now
+# TODO: will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str):
+        """
+        Get the conf "key" from the given spark context,
+        or return the default value if the conf is not set.
+        This expects the conf value to be a boolean or string;
+        if the value is a string, this checks for all capitalization
+        patterns of "true" and "false" to match Scala.
+        Args:
+            key: string for conf name
+            default_value: default value for the conf value for the given key
+        """
+        val = sc.getConf().get(key, default_value)
+        lowercase_val = val.lower()
+        if lowercase_val == "true":
+            return True
+        if lowercase_val == "false":
+            return False
+        raise Exception(
+            "get_conf_boolean expected a boolean conf "
+            "value but found value of type {} "
+            "with value: {}".format(type(val), val)
+        )
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, spark: Optional[SparkSession] = None):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+class PyTorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    TODO: Add examples in the future

Review Comment:
   I can actually just add some examples for now





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1061232558


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters

Review Comment:
   +1





[GitHub] [spark] AmplabJenkins commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1363277090

   Can one of the admins verify this patch?




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053894446


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):
+
+    def setUp(self):
+        super().setUp()
+        self.spark = SparkSession(self.sc)
+
+    def tearDown(self):
+        super().tearDown()
+
+    def test_validate_correct_inputs(self):

Review Comment:
   What is the standard for testing multiple inputs on one function? I didn't find any uses of `parametrize`, so is there an alternative approach? I saw that `self.subTest()` has been used in the past, so I started off with that.
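   For reference, a minimal standalone sketch of the `self.subTest()` pattern (a generic example, not tied to this test file):

   ```python
   import unittest

   class SubTestExample(unittest.TestCase):
       def test_doubling(self) -> None:
           # Each case runs independently; a failing case is reported
           # without aborting the remaining cases, similar to parametrize.
           for value, expected in [(1, 2), (2, 4), (3, 6)]:
               with self.subTest(value=value):
                   self.assertEqual(value * 2, expected)

   if __name__ == "__main__":
       unittest.main()
   ```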





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014127


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)

Review Comment:
   ```suggestion
       >>> distributor = TorchDistributor(
       ...     framework="pytorch",
       ...     num_processes=2,
       ...     local_mode=True,
       ...     use_gpu=True)
   ```





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014518


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else

Review Comment:
   ```suggestion
       >>> def train(learning_rate):
       ...     import torch.distributed
       ...     torch.distributed.init_process_group(backend="nccl")
       ...     # ...
       ...     torch.distributed.destroy_process_group()
       ...     return model # or anything else
   ```
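
For context on the suggestion above: in doctest format only the first line of a statement takes `>>>`; continuation lines take `...`, otherwise each line is parsed as a separate statement and the example fails to run. A minimal, self-contained sketch (the `make_adder` helper is hypothetical, purely for illustration):

```python
# Minimal doctest sketch; `make_adder` is a hypothetical helper used
# only to illustrate `>>>` vs `...` continuation lines.
def make_adder(base):
    """Return a function that adds `base` to its argument.

    >>> def call_twice(f, x):
    ...     # continuation lines inside a doctest use `...`, not `>>>`
    ...     return f(f(x))
    >>> call_twice(make_adder(1), 0)
    2
    """
    return lambda x: base + x

if __name__ == "__main__":
    import doctest
    doctest.testmod()  # the example above parses and passes
```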





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063039894


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   Why not just use something like
   ```
   bool(sc.getConf().get(key, default_value))
   ```
   and let Python handle the issue?
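
One wrinkle with that shortcut, noted here as an editor's aside: `bool()` on a string only tests emptiness, so a conf value of "false" would always read as enabled:

```python
# bool() cannot parse boolean-like conf strings: any non-empty string
# is truthy regardless of its content.
print(bool("false"))  # True  (non-empty string)
print(bool("true"))   # True
print(bool(""))       # False (only the empty string is falsy)
```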
   





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063149995


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "

Review Comment:
   Do we use `ValueError` for conf errors or `RuntimeError`?





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063173105


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   How about using some function like `distutils.util.strtobool` ?
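
For reference, a quick sketch of what `distutils.util.strtobool` does (note that it returns ints rather than bools, and that `distutils` was deprecated by PEP 632, as raised further down this thread):

```python
# Behavior sketch of distutils.util.strtobool (deprecated in Python 3.10,
# removed in 3.12): maps truthy/falsy spellings to 1/0, else raises.
from distutils.util import strtobool

print(strtobool("True"))   # 1
print(strtobool("off"))    # 0
strtobool("maybe")         # raises ValueError: invalid truth value 'maybe'
```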





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054104494


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:

Review Comment:
   please add documentation, since this is one of the main APIs.





[GitHub] [spark] HyukjinKwon closed pull request #39146: [SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon closed pull request #39146: [SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes
URL: https://github.com/apache/spark/pull/39146




[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1058882313


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters

Review Comment:
   ```suggestion
   gpu_amount_raw = int(self.sc.getConf().get("spark.task.resource.gpu.amount", "1"))
   ```
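
The one-liner works because `SparkConf.get` already accepts a default, making the `contains` branch redundant; a rough sketch of the equivalence (runs against a bare `SparkConf`, no cluster assumed):

```python
# SparkConf.get(key, default) covers both the set and unset cases,
# so an explicit contains() check adds nothing.
from pyspark import SparkConf

conf = SparkConf()
key = "spark.task.resource.gpu.amount"
print(int(conf.get(key, "1")))  # 1: conf unset, default used

conf.set(key, "2")
print(int(conf.get(key, "1")))  # 2: conf set, default ignored
```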





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055997322


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+# Moved the util functions to this file for now
+# TODO: will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str):
+        """
+        Get the conf "key" from the given spark context,
+        or return the default value if the conf is not set.
+        This expects the conf value to be a boolean or string;
+        if the value is a string, this checks for all capitalization
+        patterns of "true" and "false" to match Scala.
+        Args:
+            key: string for conf name
+            default_value: default value for the conf value for the given key
+        """
+        val = sc.getConf().get(key, default_value)
+        lowercase_val = val.lower()
+        if lowercase_val == "true":
+            return True
+        if lowercase_val == "false":
+            return False
+        raise Exception(
+            "get_conf_boolean expected a boolean conf "
+            "value but found value of type {} "
+            "with value: {}".format(type(val), val)
+        )
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, spark: Optional[SparkSession] = None):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        
+    def _get_num_tasks(self):
+        """
+        Returns:
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                task_gpu_amount = int(self.sc.getConf().get(key))
+            else:
+                task_gpu_amount = 1 # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+    
+    def _validate_input_params(self):
+        if self.num_processes <= 0:
+            raise ValueError(f"num_proccesses has to be a positive integer")
+
+    def _check_encryption(self):
+        if "ssl_conf" not in self.__dict__:
+            raise NotImplementedError()
+        is_ssl_enabled = get_conf_boolean(
+            self.sc,
+            "spark.ssl.enabled",
+            "false"
+        )
+        ignore_ssl = get_conf_boolean(
+            self.sc,
+            self.ssl_conf,
+            "false"
+        )
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+class PyTorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    TODO: Add examples in the future
+    """
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+    PICKLED_FUNC_FILE = "func.pickle"
+    TRAIN_FILE = "train.py"
+    
+    def __init__(self, framework: str, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True, spark: Optional[SparkSession] = None):
+        """Initializes the distributor

Review Comment:
   Should better follow NumPy documentation style (https://numpydoc.readthedocs.io/en/latest/format.html)
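
For readers unfamiliar with the convention being requested: numpydoc uses underlined `Parameters` / `Returns` sections rather than an `Args:` block. A minimal sketch (the `get_conf_int` helper and its wording are illustrative, not taken from the PR):

```python
def get_conf_int(sc, key, default_value):
    """Read an integer conf value, documented in NumPy style.

    Parameters
    ----------
    sc : SparkContext
        The context whose configuration is read.
    key : str
        Name of the configuration entry.
    default_value : str
        Fallback used when the conf is unset.

    Returns
    -------
    int
        The parsed configuration value.
    """
    return int(sc.getConf().get(key, default_value))
```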





[GitHub] [spark] rithwik-db commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1372985154

   Can I only request a review from one person? @HyukjinKwon




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014284


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)

Review Comment:
   ```suggestion
       >>> distributor = TorchDistributor(
       ...     framework="pytorch",
       ...     num_processes=2,
       ...     local_mode=False,
       ...     use_gpu=True)
   ```





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063014909


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(

Review Comment:
   Will need to be `RuntimeError` I believe





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063038708


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   @HyukjinKwon  What is the tradition for spark/pyspark when handling an invalid conf?
   
   I think ignoring the invalid conf instead of throwing an error may be the proper way.
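
A lenient variant along those lines might warn and fall back to the default instead of raising; a sketch of the commenter's proposal (not what the PR ultimately does):

```python
import warnings

def get_conf_boolean_lenient(sc, key, default_value):
    # On an unparseable conf value, warn and fall back to the default
    # instead of raising, as suggested above.
    val = sc.getConf().get(key, default_value)
    lowered = val.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    warnings.warn(
        f"Ignoring invalid boolean conf {key}={val!r}; "
        f"falling back to default {default_value!r}."
    )
    return default_value.lower() == "true"
```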





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063186150


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   Is this [deprecation](https://peps.python.org/pep-0632/) a relevant issue if we were to use `distutils`? 
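
It is relevant: PEP 632 deprecated `distutils` in Python 3.10 and it was removed in 3.12, so a small local helper is the safer route. A minimal sketch mirroring `strtobool`'s accepted spellings:

```python
# Local stand-in for distutils.util.strtobool (removed in Python 3.12),
# returning a real bool instead of 1/0.
_TRUTHY = {"y", "yes", "t", "true", "on", "1"}
_FALSY = {"n", "no", "f", "false", "off", "0"}

def strtobool(val: str) -> bool:
    lowered = val.strip().lower()
    if lowered in _TRUTHY:
        return True
    if lowered in _FALSY:
        return False
    raise ValueError(f"invalid truth value {val!r}")
```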





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063013544


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(

Review Comment:
   Can we raise either `RuntimeError` or `ValueError`?
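
Either is an improvement over a bare `Exception`, since callers can then catch the misconfiguration case without masking unrelated failures; a tiny sketch of the difference:

```python
# With a specific type, callers can handle just the bad-conf case.
def parse_flag(val: str) -> bool:
    lowered = val.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    raise ValueError(f"expected a boolean conf value, got {val!r}")

try:
    parse_flag("nope")
except ValueError as err:   # a bare `except Exception` would also swallow bugs
    print(f"bad conf, using default: {err}")
```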





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063118380


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    >>>     import torch.distributed
+    >>>     torch.distributed.init_process_group(backend="nccl")
+    >>>     ...
+    >>>     torch.distributed.destroy_process_group()
+    >>>     return model # or anything else
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(framework="pytorch",
+                                  num_processes=2,
+                                  local_mode=False,
+                                  use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training
+    >>> num_proc = 2
+    >>> def train():
+    >>>     from pytorch_lightning import Trainer
+    >>>     ...
+    >>>     # required to set devices = 1 and num_nodes == num_processes
+    >>>     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    >>>     trainer.fit()
+    >>>     ...
+    >>>     return trainer
+    >>> distributor = TorchDistributor(framework="pytorch-lightning",
+                                  num_processes=num_proc,
+                                  local_mode=True,
+                                  use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    available_frameworks = ["pytorch", "pytorch-lightning"]

Review Comment:
   Will make a change to address this.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063149647


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.getActiveSession()
+        if not self.spark:
+            raise RuntimeError("An active SparkSession is required for the distributor.")
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+
+        Raises
+        ------
+        RuntimeError
+            Raised when the SparkConf was misconfigured.
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if num_available_gpus == 0:
+                    raise RuntimeError("GPU resources were not configured properly on the driver.")
+                if self.num_processes > num_available_gpus:
+                    warnings.warn(
+                        f"'num_processes' cannot be set to a value greater than the number of "
+                        f"available GPUs on the driver, which is {num_available_gpus}. "
+                        f"'num_processes' was reset to be equal to the number of available GPUs.",
+                        RuntimeWarning,
+                    )
+                    self.num_processes = num_available_gpus
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        RuntimeError
+            Thrown when the user requires ssl encryption or when the user initializes
+            the Distributor parent class.
+        """
+        if not "ssl_conf":
+            raise RuntimeError(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise RuntimeError(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    ...     import torch.distributed
+    ...     torch.distributed.init_process_group(backend="nccl")
+    ...     # ...
+    ...     torch.distributed.destroy_process_group()
+    ...     return model # or anything else
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=False,
+    ...     use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training on GPU
+
+    >>> num_proc = 2
+    >>> def train():
+    ...     from pytorch_lightning import Trainer
+    ...     # ...
+    ...     # required to set devices = 1 and num_nodes == num_processes
+    ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    ...     trainer.fit()
+    ...     # ...
+    ...     return trainer
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch-lightning",
+    ...     num_processes=num_proc,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    # TODO(SPARK-41915): Remove need for setting frameworks in a future PR.
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether or not we are using PyTorch or PyTorch
+            Lightning. This could either be the string “pytorch” or ”pytorch-lightning”.

Review Comment:
   Seems like you're using special character backquotes (from Mac by default). It would have to be `"` instead of `“`
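   
   A throwaway way to normalize such characters (just an illustration, not part of the PR):
   
   ```
   # Replace curly double quotes with plain ASCII quotes.
   text = 'the string “pytorch” or ”pytorch-lightning”'
   normalized = text.replace("\u201c", '"').replace("\u201d", '"')
   print(normalized)  # the string "pytorch" or "pytorch-lightning"
   ```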





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063155739


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.getActiveSession()
+        if not self.spark:
+            raise RuntimeError("An active SparkSession is required for the distributor.")
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------

Review Comment:
   ```suggestion
           -------
           int
   ```
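   
   For reference, the complete numpydoc section would then read (a sketch following the convention used elsewhere in this file):
   
   ```
   Returns
   -------
   int
       The number of Spark tasks to use for distributed training
   ```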





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1053893575


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:

Review Comment:
   This can be the parent class for a `TensorflowDistributor` down the line. To clarify, a TensorflowDistributor is out of scope for this specific project, but it is something we ideally want to add in the future, so this does some preliminary work to make life easier later (a rough sketch follows below).
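   
   Purely as a sketch of that direction (a `TensorflowDistributor` does not exist in this PR; everything below is hypothetical):
   
   ```
   class TensorflowDistributor(Distributor):
       """Hypothetical future subclass reusing the shared setup
       (task count computation, SSL checks) from Distributor."""
   
       def run(self, train_fn, *args):
           self._check_encryption()
           # TensorFlow-specific launch logic would go here.
           ...
   ```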





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054100034


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()

Review Comment:
   Are there cases where users may want to specify the SparkSession?
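   
   For example (a sketch of the use case; the `spark` keyword argument is only being proposed here and is hypothetical at this point):
   
   ```
   from pyspark.sql import SparkSession
   
   # A user managing several sessions may want to pin the distributor to a
   # specific one instead of whatever SparkSession.builder returns.
   spark = SparkSession.builder.appName("torch-training").getOrCreate()
   distributor = PyTorchDistributor(num_processes=2, local_mode=True, use_gpu=True, spark=spark)
   ```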





[GitHub] [spark] zhengruifeng commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
zhengruifeng commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1054104494


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:

Review Comment:
   please add documentation, since this is one of the main APIs.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055998235


##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -0,0 +1,82 @@
+from pyspark.testing.utils import PySparkTestCase
+from pyspark.ml.torch.distributor import PyTorchDistributor
+from pyspark.sql import SparkSession
+
+# Q: Do you recommend that I use pytest.mark.parametrize? It doesn't seem to be used elsewhere in this code...
+class TestPyTorchDistributor(PySparkTestCase):

Review Comment:
   Can you inherit `ReusedSQLTestCase` instead? Then I think you can remove `setUp` and `tearDown`.
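   
   A minimal sketch of the suggested test setup (the `get_conf_boolean` import path follows a later revision of this PR and may differ):
   
   ```
   from pyspark.testing.sqlutils import ReusedSQLTestCase
   from pyspark.ml.torch.distributor import get_conf_boolean
   
   class TestPyTorchDistributor(ReusedSQLTestCase):
       # ReusedSQLTestCase creates a shared SparkContext/SparkSession in
       # setUpClass and stops it in tearDownClass, so per-test setUp and
       # tearDown are unnecessary.
       def test_get_conf_boolean_default(self):
           self.assertFalse(get_conf_boolean(self.sc, "spark.ssl.enabled", "false"))
   ```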





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1055794843


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .utils import get_conf_boolean
+import math
+from typing import Union, Callable
+from pyspark.sql import SparkSession
+
+
+# might need to move into its own file as we look forward.
+class Distributor:
+    def __init__(self, num_processes: int = 1, local_mode: bool = True, use_gpu: bool = True):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.builder.getOrCreate()

Review Comment:
   There actually might be; let me fix this by adding a new parameter as input.





[GitHub] [spark] WeichenXu123 commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
WeichenXu123 commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1058886036


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters

Review Comment:
   Wait, seemingly we should change logic here to be:
   ```
   if self.sc.master.startswith("local"):
     # TODO: Adding a check that GPU device exists.
     gpu_amount_raw = 1
   else:
     # if `spark.task.resource.gpu.amount` is not set, spark task
     # cannot use gpu
     gpu_amount_raw = int(self.sc.getConf().get("spark.task.resource.gpu.amount", "0"))
   ```
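   
   Under that logic, for example, a cluster conf of `spark.task.resource.gpu.amount=2` with `num_processes=8` would yield `math.ceil(8 / 2) == 4` Spark tasks, while an unset conf on a non-local master would yield 0 and surface as an error rather than silently defaulting to 1.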





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1061233548


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:

Review Comment:
   Add a docstring for this class?



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use PyTorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use PyTorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class PyTorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.

Review Comment:
   Do we support `PyTorch Lightning` in our current version?





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1062956832


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters
+            if task_gpu_amount < 1:
+                raise ValueError(
+                    f"The Spark conf `{key}` has a value "
+                    f"of {task_gpu_amount} but it "
+                    "should not have a value less than 1."
+                )
+            return math.ceil(self.num_processes / task_gpu_amount)

Review Comment:
   https://issues.apache.org/jira/browse/SPARK-41916





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063123902


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   I believe `bool("false")` would return `True` though since bool doesn't actually attempt to "comprehend" the string value, right? I guess we could do `return sc.getConf().get(key, default_value) == "true"` but that would treat all possible other values here as "false." I guess let's wait for @HyukjinKwon to give his thoughts on the matter.





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063037248


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(

Review Comment:
   +1. 





[GitHub] [spark] lu-wang-dl commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
lu-wang-dl commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063039894


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   Why not just use something like
   ```
   bool(sc.getConf().get(key, default_value))
   ```
   and let Python handle the issue?
   Then we don't need to create an function.
   





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063013781


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise Exception(

Review Comment:
   ditto can we throw a different exception? 
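   
   For reference, the follow-up revision quoted later in this thread switches these from the bare `Exception` to `RuntimeError`, along the lines of:
   ```
   # pattern from the later revision of this PR: a built-in, more specific
   # exception type instead of bare Exception
   raise RuntimeError(
       "Distributor doesn't have this functionality. Use TorchDistributor instead."
   )
   ```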



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if self.num_processes > num_available_gpus:
+                    raise ValueError(
+                        f"For local training, {self.num_processes} can be at most"
+                        f"equal to the amount of GPUs available,"
+                        f"which is {num_available_gpus}."
+                    )
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        NotImplementedError
+            Thrown when the user doesn't use TorchDistributor
+        Exception
+            Thrown when the user requires ssl encryption
+        """
+        if not "ssl_conf":
+            raise Exception(

Review Comment:
   ditto





[GitHub] [spark] HyukjinKwon commented on pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #39146:
URL: https://github.com/apache/spark/pull/39146#issuecomment-1373003031

   > Can I only request a review from one person? @HyukjinKwon
   
   I believe you can request a review from just one person via:
   ![Screen Shot 2023-01-06 at 10 07 53 AM](https://user-images.githubusercontent.com/6477701/210909183-14d627c7-e885-4b97-bfb0-d4e244f2827d.png)
   




[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063007162


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,287 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# Moved the util functions to this file for now
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    Exception
+        Thrown when the conf value is not a boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise Exception(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+        spark: Optional[SparkSession] = None,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        if spark:
+            self.spark = spark
+        else:
+            self.spark = SparkSession.builder.getOrCreate()
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+        """
+        if self.use_gpu:
+            key = "spark.task.resource.gpu.amount"
+            if self.sc.getConf().contains(key):
+                if gpu_amount_raw := self.sc.getConf().get(key):  # mypy error??
+                    task_gpu_amount = int(gpu_amount_raw)
+            else:
+                task_gpu_amount = 1  # for single node clusters

Review Comment:
   @WeichenXu123 @lu-wang-dl, I added an updated check. Can you please let me know if this looks more reasonable?
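   
   As a side note, the `# mypy error??` above is most likely mypy flagging `task_gpu_amount` as possibly unbound: the walrus assignment only runs when `contains(key)` is true and `get(key)` returns a truthy value. A minimal sketch that sidesteps this, assuming the same SparkConf API already used elsewhere in this diff (`get` accepts a default):
   ```
   # read the task GPU amount, defaulting to "1" for single-node clusters
   # where the key is unset; no conditionally-assigned variable needed
   task_gpu_amount = int(self.sc.getConf().get("spark.task.resource.gpu.amount", "1"))
   ```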





[GitHub] [spark] rithwik-db commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
rithwik-db commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063150272


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,307 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
+    """Get the conf "key" from the given spark context,
+    or return the default value if the conf is not set.
+    This expects the conf value to be a boolean or string;
+    if the value is a string, this checks for all capitalization
+    patterns of "true" and "false" to match Scala.
+
+    Parameters
+    ----------
+    sc : SparkContext
+        The SparkContext for the distributor.
+    key : str
+        string for conf name
+    default_value : str
+        default value for the conf value for the given key
+
+    Returns
+    -------
+    bool
+        Returns the boolean value that corresponds to the conf
+
+    Raises
+    ------
+    RuntimeError
+        Thrown when the conf value is not a valid boolean
+    """
+    val = sc.getConf().get(key, default_value)
+    lowercase_val = val.lower()
+    if lowercase_val == "true":
+        return True
+    if lowercase_val == "false":
+        return False
+    raise RuntimeError(
+        "get_conf_boolean expected a boolean conf "
+        "value but found value of type {} "
+        "with value: {}".format(type(val), val)
+    )
+
+
+class Distributor:
+    """
+    The parent class for TorchDistributor. This class shouldn't be instantiated directly.
+    """
+
+    def __init__(
+        self,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        self.num_processes = num_processes
+        self.local_mode = local_mode
+        self.use_gpu = use_gpu
+        self.spark = SparkSession.getActiveSession()
+        if not self.spark:
+            raise RuntimeError("An active SparkSession is required for the distributor.")
+        self.sc = self.spark.sparkContext
+        self.num_tasks = self._get_num_tasks()
+        self.ssl_conf = None
+
+    def _get_num_tasks(self) -> int:
+        """
+        Returns the number of Spark tasks to use for distributed training
+
+        Returns
+        -------
+            The number of Spark tasks to use for distributed training
+
+        Raises
+        ------
+        RuntimeError
+            Raised when the SparkConf was misconfigured.
+        """
+
+        if self.use_gpu:
+            if not self.local_mode:
+                key = "spark.task.resource.gpu.amount"
+                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                if task_gpu_amount < 1:
+                    raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
+                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
+                return math.ceil(self.num_processes / task_gpu_amount)
+            else:
+                key = "spark.driver.resource.gpu.amount"
+                if "gpu" not in self.sc.resources:
+                    raise RuntimeError("GPUs were unable to be found on the driver.")
+                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                if num_available_gpus == 0:
+                    raise RuntimeError("GPU resources were not configured properly on the driver.")
+                if self.num_processes > num_available_gpus:
+                    warnings.warn(
+                        f"'num_processes' cannot be set to a value greater than the number of "
+                        f"available GPUs on the driver, which is {num_available_gpus}. "
+                        f"'num_processes' was reset to be equal to the number of available GPUs.",
+                        RuntimeWarning,
+                    )
+                    self.num_processes = num_available_gpus
+        return self.num_processes
+
+    def _validate_input_params(self) -> None:
+        if self.num_processes <= 0:
+            raise ValueError("num_proccesses has to be a positive integer")
+
+    def _check_encryption(self) -> None:
+        """Checks to see if the user requires encrpytion of data.
+        If required, throw an exception since we don't support that.
+
+        Raises
+        ------
+        RuntimeError
+            Thrown when the user requires ssl encryption or when the user initializes
+            the Distributor parent class.
+        """
+        if not "ssl_conf":
+            raise RuntimeError(
+                "Distributor doesn't have this functionality. Use TorchDistributor instead."
+            )
+        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", "false")
+        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # type: ignore
+        if is_ssl_enabled:
+            name = self.__class__.__name__
+            if ignore_ssl:
+                warnings.warn(
+                    f"""
+                    This cluster has TLS encryption enabled;
+                    however, {name} does not
+                    support data encryption in transit.
+                    The Spark configuration
+                    '{self.ssl_conf}' has been set to
+                    'true' to override this
+                    configuration and use {name} anyway. Please
+                    note this will cause model
+                    parameters and possibly training data to
+                    be sent between nodes unencrypted.
+                    """,
+                    RuntimeWarning,
+                )
+                return
+            raise RuntimeError(
+                f"""
+                This cluster has TLS encryption enabled;
+                however, {name} does not support
+                data encryption in transit. To override
+                this configuration and use {name}
+                anyway, you may set '{self.ssl_conf}'
+                to 'true' in the Spark configuration. Please note this
+                will cause model parameters and possibly training
+                data to be sent between nodes unencrypted.
+                """
+            )
+
+
+class TorchDistributor(Distributor):
+    """
+    A class to support distributed training on PyTorch and PyTorch Lightning using PySpark.
+
+    .. versionadded:: 3.4.0
+
+    Examples
+    --------
+
+    Run PyTorch Training locally on GPU (using a PyTorch native function)
+
+    >>> def train(learning_rate):
+    ...     import torch.distributed
+    ...     torch.distributed.init_process_group(backend="nccl")
+    ...     # ...
+    ...     torch.distributed.destroy_process_group()
+    ...     return model # or anything else
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> model = distributor.run(train, 1e-3)
+
+    Run PyTorch Training on GPU (using a file with PyTorch code)
+
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch",
+    ...     num_processes=2,
+    ...     local_mode=False,
+    ...     use_gpu=True)
+    >>> distributor.run("/path/to/train.py", *args)
+
+    Run PyTorch Lightning Training on GPU
+
+    >>> num_proc = 2
+    >>> def train():
+    ...     from pytorch_lightning import Trainer
+    ...     # ...
+    ...     # devices must be set to 1 and num_nodes must equal num_processes
+    ...     trainer = Trainer(accelerator="gpu", devices=1, num_nodes=num_proc, strategy="ddp")
+    ...     trainer.fit()
+    ...     # ...
+    ...     return trainer
+    >>> distributor = TorchDistributor(
+    ...     framework="pytorch-lightning",
+    ...     num_processes=num_proc,
+    ...     local_mode=True,
+    ...     use_gpu=True)
+    >>> trainer = distributor.run(train)
+    """
+
+    # TODO(SPARK-41915): Remove need for setting frameworks in a future PR.
+    available_frameworks = ["pytorch", "pytorch-lightning"]
+
+    def __init__(
+        self,
+        framework: str,
+        num_processes: int = 1,
+        local_mode: bool = True,
+        use_gpu: bool = True,
+    ):
+        """Initializes the distributor.
+
+        Parameters
+        ----------
+        framework : str
+            A string indicating whether we are using PyTorch or PyTorch
+            Lightning. This should be either "pytorch" or "pytorch-lightning".

Review Comment:
   Thanks for catching that!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HyukjinKwon commented on a diff in pull request #39146: [WIP][SPARK-41589][PYTHON][ML] PyTorch Distributor Baseline API Changes

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #39146:
URL: https://github.com/apache/spark/pull/39146#discussion_r1063147515


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -0,0 +1,297 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import math
+from typing import Union, Callable, Optional, Any
+import warnings
+
+from pyspark.sql import SparkSession
+from pyspark.context import SparkContext
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:

Review Comment:
   I know some places we ignore invalid configuration values but I believe it's more common to throw an exception on invalid values so I am fine with throwing an exception. I am fine with the current logic.
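   
   For illustration, with the strict parse in the revision above, an invalid value fails loudly rather than being silently coerced (a hypothetical session, assuming an active SparkContext `sc` and that the conf key is unset, so the deliberately invalid default is returned):
   ```
   >>> get_conf_boolean(sc, "spark.ssl.enabled", "nope")
   Traceback (most recent call last):
       ...
   RuntimeError: get_conf_boolean expected a boolean conf value but found value of type <class 'str'> with value: nope
   ```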


