Posted to reviews@spark.apache.org by "zhengruifeng (via GitHub)" <gi...@apache.org> on 2023/04/10 11:17:59 UTC

[GitHub] [spark] zhengruifeng commented on a diff in pull request #40724: [SPARK-43081] [ML] [CONNECT] Add torch distributor data loader that loads data from spark partition data

zhengruifeng commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1161636110


##########
python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py:
##########
@@ -0,0 +1,51 @@
+#

Review Comment:
   we need to add the two test files into `modules.py`
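   For reference, a sketch of the two test goals that would need to be registered (which existing `Module` entries in `dev/sparktestsupport/modules.py` they belong to is an assumption here):
   ```python
   # Hypothetical sketch: the two new test goals to append to the appropriate
   # python_test_goals lists in dev/sparktestsupport/modules.py
   # (e.g. the pyspark_ml and pyspark_connect Module definitions).
   NEW_PYTHON_TEST_GOALS = [
       "pyspark.ml.torch.tests.test_data_loader",                  # plain PySpark test
       "pyspark.ml.tests.connect.test_parity_torch_data_loader",   # connect parity test
   ]
   ```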



##########
python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py:
##########
@@ -0,0 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import unittest
+from pyspark.sql import SparkSession
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
+
+
+@unittest.skipIf(not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder \
+            .remote("local[1]") \
+            .config("spark.default.parallelism", "1") \
+            .getOrCreate()

Review Comment:
   ```suggestion
           self.spark = (
               SparkSession.builder.remote("local[1]")
               .config("spark.default.parallelism", "1")
               .getOrCreate()
           )
   ```



##########
python/pyspark/ml/torch/data.py:
##########
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import numpy as np
+
+
+class SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
+
+    def __init__(self, arrow_file_path, schema, num_samples):
+        self.arrow_file_path = arrow_file_path
+        self.num_samples = num_samples
+        self.field_types = [field.dataType.simpleString() for field in schema]

Review Comment:
   why do we need this? the Arrow data itself already contains the schema
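   To illustrate with a minimal pyarrow sketch (not the PR's code; the file path is hypothetical): the Arrow IPC stream written for each partition already carries its schema, which the reader can introspect instead of receiving `field_types` through the constructor.
   ```python
   import pyarrow as pa

   # Minimal sketch: the Arrow stream embeds its own schema, so field names
   # and types are recoverable on the reading side without a side channel.
   with pa.OSFile("/tmp/partition.arrow", "rb") as f:    # hypothetical path
       reader = pa.ipc.open_stream(f)
       print(reader.schema)                               # schema of the whole stream
       for batch in reader:
           print(batch.schema.names, batch.schema.types)  # same schema on every batch
   ```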



##########
python/pyspark/ml/torch/data.py:
##########
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import numpy as np
+
+
+class SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
+
+    def __init__(self, arrow_file_path, schema, num_samples):
+        self.arrow_file_path = arrow_file_path
+        self.num_samples = num_samples
+        self.field_types = [field.dataType.simpleString() for field in schema]
+
+    @staticmethod
+    def _extract_field_value(value, field_type):
+        # TODO: avoid checking field type for every row.
+        if field_type == "vector":
+            if value['type'] == 1:
+                # dense vector
+                return value['values']
+            if value['type'] == 0:
+                # sparse vector
+                size = int(value['size'])
+                np_array = np.zeros(size, dtype=np.float64)
+                for index, elem_value in zip(value['indices'], value['values']):
+                    np_array[index] = elem_value
+                return np_array
+        if field_type in ["float", "double", "int", "bigint", "smallint"]:
+            return value
+
+        raise ValueError(
+            "SparkPartitionTorchDataset does not support loading data from field of "
+            f"type {field_type}."
+        )
+
+    def __iter__(self):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        serializer = ArrowStreamSerializer()
+
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None and worker_info.num_workers > 1:
+            raise RuntimeError(
+                "`SparkPartitionTorchDataset` does not support multiple worker processes."
+            )
+
+        count = 0
+
+        while count < self.num_samples:
+            with open(self.arrow_file_path, "rb") as f:
+                batch_iter = serializer.load_stream(f)
+                for batch in batch_iter:
+                    # TODO: we can optimize this further by directly extracting

Review Comment:
   What about changing the return type from `List` to `Row`?
   Then I guess we can reuse 
   
   https://github.com/apache/spark/blob/0e9e34c1bd9bd16ad5efca77ce2763eb950f3103/python/pyspark/sql/connect/dataframe.py#L1468-L1474
   
   which converts an Arrow table to `Row`s
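   Roughly along these lines (a sketch using plain pyarrow and `pyspark.sql.Row`; the linked helper may differ in detail):
   ```python
   import pyarrow as pa
   from pyspark.sql import Row

   def arrow_batch_to_rows(batch: pa.RecordBatch) -> list:
       # Sketch: convert one Arrow record batch into pyspark Rows, letting
       # Arrow's own schema drive the conversion.
       table = pa.Table.from_batches([batch])
       return [Row(**record) for record in table.to_pylist()]
   ```
   Vector-typed columns would still need the UDT handling from `_extract_field_value`; this only covers primitive columns.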



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -619,8 +638,10 @@ def _run_distributed_training(
 
     @staticmethod
     def _run_training_on_pytorch_file(

Review Comment:
   since `input_dataframe` is only supported when `train_object` is a function, do we need to add a check in `run`? 
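   A minimal sketch of what such a guard could look like (names follow the surrounding diff; where exactly it should live is the open question here):
   ```python
   # Hypothetical check: an input DataFrame only makes sense together with a
   # training function, not with a path to a training file.
   if spark_dataframe is not None and not callable(train_object):
       raise ValueError(
           "An input DataFrame is only supported when `train_object` is a "
           "function, not a file path."
       )
   ```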



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -744,7 +800,44 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
                 TorchDistributor._run_training_on_pytorch_function  # type: ignore
             )
         if self.local_mode:
-            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
         else:
-            output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_distributed_training(
+                framework_wrapper_fn, train_object, None, *args, **kwargs
+            )
         return output
+
+    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
+        if self.local_mode:
+            raise ValueError(
+                "`TorchDistributor.train_on_dataframe` requires setting `TorchDistributor.local_mode` to `False`."
+            )
+
+        return self._run_distributed_training(
+            TorchDistributor._run_training_on_pytorch_function,
+            train_function,
+            spark_dataframe,
+            *args,
+            **kwargs
+        )
+
+
+def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):

Review Comment:
   why `prefetch=2`?
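   Presumably it mirrors `torch.utils.data.DataLoader`'s default `prefetch_factor=2`. A sketch of how the parameter would typically be threaded through (assuming the loader wraps the iterable dataset with a single worker process, since the dataset above rejects `num_workers > 1`; the helper name is made up):
   ```python
   from torch.utils.data import DataLoader

   def make_partition_loader(dataset, batch_size, prefetch=2):
       # prefetch_factor is only honored when num_workers > 0; with one worker
       # it keeps `prefetch` batches buffered ahead of the training loop.
       return DataLoader(
           dataset,
           batch_size=batch_size,
           num_workers=1,
           prefetch_factor=prefetch,
       )
   ```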



##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -643,13 +664,48 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N
         finally:
             TorchDistributor._cleanup_files(save_dir)
 
+    @staticmethod
+    @contextmanager
+    def _setup_spark_partition_data(partition_data_iterator, input_schema_json):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        import json
+
+        if input_schema_json is None:
+            yield
+            return
+
+        save_dir = TorchDistributor._create_save_dir()
+
+        try:
+            serializer = ArrowStreamSerializer()

Review Comment:
   ditto, I guess we can reuse `LocalDataToArrowConversion` to convert the rows to an Arrow table
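   For illustration, assuming `LocalDataToArrowConversion` from `pyspark.sql.connect.conversion` keeps roughly its current `convert(rows, schema)` shape, the write side could look like:
   ```python
   import pyarrow as pa
   from pyspark.sql.connect.conversion import LocalDataToArrowConversion

   def write_partition_as_arrow(rows, schema, path):
       # Sketch (the `convert` signature is assumed): turn the partition's rows
       # into an Arrow table, then dump it as an IPC stream for the data loader.
       table = LocalDataToArrowConversion.convert(list(rows), schema)
       with pa.OSFile(path, "wb") as sink:
           with pa.ipc.new_stream(sink, table.schema) as writer:
               writer.write_table(table)
   ```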



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

