Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/10 01:09:00 UTC

[spark] branch master updated: [SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f35c2cbdae1 [SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`
f35c2cbdae1 is described below

commit f35c2cbdae1c7d35f61b437d056bd363cddbea61
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Fri Mar 10 10:08:43 2023 +0900

    [SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.mapInArrow`.
    
    ### Why are the changes needed?
    To reach API parity with vanilla PySpark, where `DataFrame.mapInArrow` is already supported.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `DataFrame.mapInArrow` is supported as shown below.
    
    ```
    >>> import pyarrow
    >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
    >>> def filter_func(iterator):
    ...   for batch in iterator:
    ...     pdf = batch.to_pandas()
    ...     yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
    ...
    >>> df.mapInArrow(filter_func, df.schema).show()
    +---+---+
    | id|age|
    +---+---+
    |  1| 21|
    +---+---+
    ```
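    
    The pandas round trip above is optional; the same filter can stay in Arrow end to end. A sketch of a pandas-free variant, assuming a pyarrow version whose compute module accepts record batches (column 0 is `id`):
    
    ```
    >>> import pyarrow.compute as pc
    >>> def filter_func(iterator):
    ...   for batch in iterator:
    ...     # Build a boolean mask over column 0 ("id") and filter the batch with it.
    ...     yield pc.filter(batch, pc.equal(batch.column(0), 1))
    ...
    ```
    
    Passing this `filter_func` to `mapInArrow` as above yields the same single-row result.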
    
    ### How was this patch tested?
    Unit tests: the existing `mapInArrow` tests are refactored into a reusable mixin, and a new Connect parity suite (`test_parity_arrow_map.py`) runs the same test bodies against Spark Connect.
    
    Closes #40350 from xinrong-meng/mapInArrowImpl.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 +++
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/sql/connect/_typing.py              |  3 ++
 python/pyspark/sql/connect/dataframe.py            | 25 +++++++++---
 .../sql/tests/connect/test_parity_arrow_map.py     | 37 +++++++++++++++++
 python/pyspark/sql/tests/test_arrow_map.py         | 46 +++++++++++-----------
 6 files changed, 90 insertions(+), 27 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f2478f548e7..5dd0a7ea309 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -480,6 +480,11 @@ class SparkConnectPlanner(val session: SparkSession) {
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
           transformRelation(rel.getInput))
+      case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
+        logical.PythonMapInArrow(
+          pythonUdf,
+          pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+          transformRelation(rel.getInput))
       case _ =>
         throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
     }
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index d8dee651c2b..751f0687f2c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -535,6 +535,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
+        "pyspark.sql.tests.connect.test_parity_arrow_map",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index c91d4e629d8..6df3f15d87d 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -26,6 +26,7 @@ from typing import Any, Callable, Iterable, Union, Optional
 import datetime
 import decimal
 
+import pyarrow
 from pandas.core.frame import DataFrame as PandasDataFrame
 
 from pyspark.sql.connect.column import Column
@@ -50,6 +51,8 @@ DataFrameLike = PandasDataFrame
 
 PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]
 
+ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
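
As a reading aid: the new `ArrowMapIterFunction` alias matches any generator that consumes an iterator of `pyarrow.RecordBatch`es and yields record batches back. A minimal illustration (the `id` column and the doubling are hypothetical, not part of this patch):

```
from typing import Iterable

import pyarrow


def double_ids(batches: Iterable[pyarrow.RecordBatch]) -> Iterable[pyarrow.RecordBatch]:
    # Batches stream through one at a time, so no partition is materialized in full.
    for batch in batches:
        pdf = batch.to_pandas()
        pdf["id"] = pdf["id"] * 2
        yield pyarrow.RecordBatch.from_pandas(pdf)
```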
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 69921896f46..0e114f9fedb 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,6 +76,7 @@ if TYPE_CHECKING:
         PrimitiveType,
         OptionalPrimitiveType,
         PandasMapIterFunction,
+        ArrowMapIterFunction,
     )
     from pyspark.sql.connect.session import SparkSession
 
@@ -1572,8 +1573,11 @@ class DataFrame:
     def storageLevel(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("storageLevel() is not implemented.")
 
-    def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    def _map_partitions(
+        self,
+        func: "PandasMapIterFunction",
+        schema: Union[StructType, str],
+        evalType: int,
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
 
@@ -1581,7 +1585,9 @@ class DataFrame:
             raise Exception("Cannot mapInPandas when self._plan is empty.")
 
         udf_obj = UserDefinedFunction(
-            func, returnType=schema, evalType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+            func,
+            returnType=schema,
+            evalType=evalType,
         )
 
         return DataFrame.withPlan(
@@ -1589,10 +1595,19 @@ class DataFrame:
             session=self._session,
         )
 
+    def mapInPandas(
+        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
+
     mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
 
-    def mapInArrow(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("mapInArrow() is not implemented.")
+    def mapInArrow(
+        self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
+
+    mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__
 
     def writeStream(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("writeStream() is not implemented.")
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
new file mode 100644
index 00000000000..ed51d0d3d19
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_arrow_map import MapInArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_arrow_map import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index 6166cc5dcc8..ff3d9b96b6b 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -37,28 +37,7 @@ if have_pandas:
     not have_pandas or not have_pyarrow,
     pandas_requirement_message or pyarrow_requirement_message,
 )
-class MapInArrowTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
+class MapInArrowTestsMixin(object):
     def test_map_in_arrow(self):
         def func(iterator):
             for batch in iterator:
@@ -126,6 +105,29 @@ class MapInArrowTests(ReusedSQLTestCase):
         self.assertEqual(sorted(actual), sorted(expected))
 
 
+class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow_map import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org