Posted to reviews@spark.apache.org by "jaceklaskowski (via GitHub)" <gi...@apache.org> on 2023/04/11 14:14:14 UTC

[GitHub] [spark] jaceklaskowski commented on a diff in pull request #40724: [SPARK-43081] [ML] [CONNECT] Add torch distributor data loader that loads data from spark partition data

jaceklaskowski commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1162873700


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -643,13 +664,52 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N
         finally:
             TorchDistributor._cleanup_files(save_dir)
 
+    @staticmethod
+    @contextmanager
+    def _setup_spark_partition_data(partition_data_iterator, input_schema_json):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        from pyspark.files import SparkFiles
+        import json
+
+        if input_schema_json is None:
+            yield
+            return
+
+        # We need to temporarily write partition data into a temp dir,
+        # partition data might be huge, so we need to write it under
+        # configured `SPARK_LOCAL_DIRS`.
+        save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())
+
+        try:
+            serializer = ArrowStreamSerializer()
+            arrow_file_path = os.path.join(save_dir, "data.arrow")
+            with open(arrow_file_path, "wb") as f:
+                serializer.dump_stream(partition_data_iterator, f)
+                if f.tell() == 0:
+                    # Nothing is written to file, this partition is empty
+                    raise ValueError(
+                        "Empty spark DataFrame partition is not allowed if you run "

Review Comment:
   nit: Spark partition (uppercase + no need for DataFrame)? Or just `DataFrame partition`? I'd also consider "is not allowed in TorchDistributor.train_on_dataframe".



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -731,11 +791,21 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
 
             where since the input is a path, all of the parameters are strings that can be
             handled by argparse in that python file.
+        kwargs :
+            If train_object is a python function and not a path to a python file, kwargs need
+            to be the key-work input parameters to that function. It would look like

Review Comment:
   nit: `key-word`, not `key-work`?



##########
python/pyspark/ml/torch/data.py:
##########
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import numpy as np
+
+
+class SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
+
+    def __init__(self, arrow_file_path, schema, num_samples):
+        self.arrow_file_path = arrow_file_path
+        self.num_samples = num_samples
+        self.field_types = [field.dataType.simpleString() for field in schema]
+
+    @staticmethod
+    def _extract_field_value(value, field_type):
+        # TODO: avoid checking field type for every row.
+        if field_type == "vector":
+            if value['type'] == 1:
+                # dense vector
+                return value['values']
+            if value['type'] == 0:
+                # sparse vector
+                size = int(value['size'])
+                np_array = np.zeros(size, dtype=np.float64)
+                for index, elem_value in zip(value['indices'], value['values']):
+                    np_array[index] = elem_value
+                return np_array
+        if field_type in ["float", "double", "int", "bigint", "smallint"]:
+            return value
+
+        raise ValueError(
+            "SparkPartitionTorchDataset does not support loading data from field of "
+            f"type {field_type}."
+        )
+
+    def __iter__(self):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        serializer = ArrowStreamSerializer()
+
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None and worker_info.num_workers > 1:
+            raise RuntimeError(
+                "`SparkPartitionTorchDataset` does not support multiple worker processes."

Review Comment:
   nit: remove backticks?
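
For context, the sparse-to-dense branch in `_extract_field_value` above amounts to the following. The concrete numbers are made up for illustration; `type` 0 marks a sparse Spark ML vector and `type` 1 a dense one, as in the diff.

    import numpy as np

    # A sparse Spark ML vector as it looks after Arrow deserialization
    # (illustrative values only).
    value = {"type": 0, "size": 5, "indices": [1, 3], "values": [2.0, 5.0]}

    np_array = np.zeros(int(value["size"]), dtype=np.float64)
    for index, elem_value in zip(value["indices"], value["values"]):
        np_array[index] = elem_value

    print(np_array)  # [0. 2. 0. 5. 0.]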



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -744,7 +814,99 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
                 TorchDistributor._run_training_on_pytorch_function  # type: ignore
             )
         if self.local_mode:
-            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
         else:
-            output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_distributed_training(
+                framework_wrapper_fn, train_object, None, *args, **kwargs
+            )
         return output
+
+    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
+        """
+        Runs distributed training using provided spark DataFrame as input data.
+        You should ensure the input spark DataFrame have evenly divided partitions,

Review Comment:
   Spark + has
   
   "divided" or "distributed"?



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -744,7 +814,99 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
                 TorchDistributor._run_training_on_pytorch_function  # type: ignore
             )
         if self.local_mode:
-            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
         else:
-            output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_distributed_training(
+                framework_wrapper_fn, train_object, None, *args, **kwargs
+            )
         return output
+
+    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
+        """
+        Runs distributed training using provided spark DataFrame as input data.
+        You should ensure the input spark DataFrame have evenly divided partitions,
+        and this method starts a barrier spark job that each spark task in the job
+        process one partition of the input spark DataFrame.
+
+        Parameters
+        ----------
+        train_function :
+            Either a PyTorch function, PyTorch Lightning function that launches distributed
+            training. Note that inside the function, you can call
+            `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get a torch
+            data loader, the data loader loads data from the corresponding partition of the
+            input spark DataFrame.
+        spark_dataframe :
+            An input spark DataFrame that can be used in PyTorch `train_function` function.
+            See `train_function` argument doc for details.
+        args :
+            `args` need to be the input parameters to `train_function` function. It would look like
+
+            >>> model = distributor.run(train, 1e-3, 64)
+
+            where train is a function and 1e-3 and 64 are regular numeric inputs to the function.
+        kwargs :
+            `kwargs` need to be the key-work input parameters to `train_function` function.

Review Comment:
   work?
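
For context, a rough usage sketch of the API described in this docstring, assuming the `get_spark_partition_data_loader` helper mentioned above (its exact signature is not shown in this hunk, so the arguments, data layout, and model details below are illustrative only, and `spark` is assumed to be an active SparkSession):

    from pyspark.ml.torch.distributor import TorchDistributor

    def train(learning_rate, batch_size=32):
        # Hypothetical usage; the helper's real signature may differ.
        from pyspark.ml.torch.distributor import get_spark_partition_data_loader

        loader = get_spark_partition_data_loader(batch_size)
        for batch in loader:
            ...  # forward/backward pass over this task's partition of the DataFrame
        return "model state or metrics"

    distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False)
    df = spark.read.parquet("/path/to/training_data")  # hypothetical input DataFrame
    model = distributor.train_on_dataframe(train, df, 1e-3, batch_size=64)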



##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -431,7 +431,7 @@ def test_dist_training_succeeds(self) -> None:
                 )
                 self.assertEqual(
                     expected,
-                    dist._run_distributed_training(dist._run_training_on_pytorch_file, "..."),
+                    dist._run_distributed_training(dist._run_training_on_pytorch_file, "...", None),

Review Comment:
   Possible to have `None` as the default?
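
   For context, a minimal sketch of what that suggestion could look like (the parameter name `spark_dataframe` is an assumption; the real signature is not shown in this hunk):

       def _run_distributed_training(
           self, framework_wrapper_fn, train_object, spark_dataframe=None, *args, **kwargs
       ):
           ...

   One caveat: callers that forward extra positional arguments via `*args` (as `run` does in this PR) still need to pass the placeholder explicitly, since otherwise the first forwarded argument would bind to `spark_dataframe`.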



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -643,13 +664,52 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N
         finally:
             TorchDistributor._cleanup_files(save_dir)
 
+    @staticmethod
+    @contextmanager
+    def _setup_spark_partition_data(partition_data_iterator, input_schema_json):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        from pyspark.files import SparkFiles
+        import json
+
+        if input_schema_json is None:
+            yield
+            return
+
+        # We need to temporarily write partition data into a temp dir,
+        # partition data might be huge, so we need to write it under
+        # configured `SPARK_LOCAL_DIRS`.
+        save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())
+
+        try:
+            serializer = ArrowStreamSerializer()
+            arrow_file_path = os.path.join(save_dir, "data.arrow")
+            with open(arrow_file_path, "wb") as f:
+                serializer.dump_stream(partition_data_iterator, f)
+                if f.tell() == 0:
+                    # Nothing is written to file, this partition is empty
+                    raise ValueError(
+                        "Empty spark DataFrame partition is not allowed if you run "
+                        "`TorchDistributor.train_on_dataframe`."
+                    )
+
+            schema_file_path = os.path.join(save_dir, "schema.json")
+            with open(schema_file_path, "w") as f:
+                json.dump(input_schema_json, f)
+
+            os.environ[SPARK_PARTITION_ARROW_DATA_FILE] = arrow_file_path
+            os.environ[SPARK_DATAFRAME_SCHEMA_FILE] = schema_file_path
+            yield
+        finally:
+            os.environ.pop(SPARK_PARTITION_ARROW_DATA_FILE)
+            os.environ.pop(SPARK_DATAFRAME_SCHEMA_FILE)
+            TorchDistributor._cleanup_files(save_dir)
+
     @staticmethod
     def _run_training_on_pytorch_function(
-        input_params: Dict[str, Any], train_fn: Callable, *args: Any
+        input_params: Dict[str, Any], train_fn: Callable, *args: Any, **kwargs
     ) -> Any:
-        with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
-            args = []  # type: ignore
-            TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
+        with TorchDistributor._setup_files(train_fn, *args, **kwargs) as (train_file_path, output_file_path):
+            TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path)
             if not os.path.exists(output_file_path):
                 raise RuntimeError(
                     "TorchDistributor failed during training. "

Review Comment:
   nit: Remove the space after `.`?
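
To connect the pieces: the context manager above writes the partition's Arrow stream and schema to local files and publishes their paths through the two environment variables, which the training process can read back. A rough sketch of the reading side, assuming the constants hold environment-variable names matching their own identifiers and that the same ArrowStreamSerializer is used for loading (illustrative only, not the PR's actual loader code):

    import json
    import os

    from pyspark.sql.pandas.serializers import ArrowStreamSerializer

    # Paths published by _setup_spark_partition_data for the duration of training.
    # The environment-variable names are assumed to equal the constant names.
    arrow_file_path = os.environ["SPARK_PARTITION_ARROW_DATA_FILE"]
    schema_file_path = os.environ["SPARK_DATAFRAME_SCHEMA_FILE"]

    with open(schema_file_path, "r") as f:
        schema_json = json.load(f)

    serializer = ArrowStreamSerializer()
    with open(arrow_file_path, "rb") as f:
        for batch in serializer.load_stream(f):
            # Each element is an Arrow record batch holding rows of this partition.
            ...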



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -731,11 +791,21 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
 
             where since the input is a path, all of the parameters are strings that can be
             handled by argparse in that python file.
+        kwargs :
+            If train_object is a python function and not a path to a python file, kwargs need
+            to be the key-work input parameters to that function. It would look like
+
+            >>> model = distributor.run(train, tol=1e-3, max_iter=64)
+
+            where train is a function that has 2 arguments `tol` and `max_iter`.

Review Comment:
   nit: s/that has/of



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -744,7 +814,99 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
                 TorchDistributor._run_training_on_pytorch_function  # type: ignore
             )
         if self.local_mode:
-            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
         else:
-            output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_distributed_training(
+                framework_wrapper_fn, train_object, None, *args, **kwargs
+            )
         return output
+
+    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
+        """
+        Runs distributed training using provided spark DataFrame as input data.

Review Comment:
   nit: s/spark/Spark



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

