You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/10 01:09:14 UTC
[spark] branch branch-3.4 updated: [SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 94a2afce8eb [SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`
94a2afce8eb is described below
commit 94a2afce8eb290feb34873025ae3065caa6c2b36
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Fri Mar 10 10:08:43 2023 +0900
[SPARK-42726][CONNECT][PYTHON] Implement `DataFrame.mapInArrow`
### What changes were proposed in this pull request?
Implement `DataFrame.mapInArrow`.
### Why are the changes needed?
Parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. `DataFrame.mapInArrow` is supported as shown below.
```
>>> import pyarrow
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
...     for batch in iterator:
...         pdf = batch.to_pandas()
...         yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
...
>>> df.mapInArrow(filter_func, df.schema).show()
+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+
```
### How was this patch tested?
Unit tests.
Closes #40350 from xinrong-meng/mapInArrowImpl.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit f35c2cbdae1c7d35f61b437d056bd363cddbea61)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 5 +++
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/_typing.py | 3 ++
python/pyspark/sql/connect/dataframe.py | 25 +++++++++---
.../sql/tests/connect/test_parity_arrow_map.py | 37 +++++++++++++++++
python/pyspark/sql/tests/test_arrow_map.py | 46 +++++++++++-----------
6 files changed, 90 insertions(+), 27 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f2478f548e7..5dd0a7ea309 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -480,6 +480,11 @@ class SparkConnectPlanner(val session: SparkSession) {
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput))
+ case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
+ logical.PythonMapInArrow(
+ pythonUdf,
+ pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+ transformRelation(rel.getInput))
case _ =>
throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index d8dee651c2b..751f0687f2c 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -535,6 +535,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_udf",
"pyspark.sql.tests.connect.test_parity_pandas_udf",
"pyspark.sql.tests.connect.test_parity_pandas_map",
+ "pyspark.sql.tests.connect.test_parity_arrow_map",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index c91d4e629d8..6df3f15d87d 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -26,6 +26,7 @@ from typing import Any, Callable, Iterable, Union, Optional
import datetime
import decimal
+import pyarrow
from pandas.core.frame import DataFrame as PandasDataFrame
from pyspark.sql.connect.column import Column
@@ -50,6 +51,8 @@ DataFrameLike = PandasDataFrame
PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]
+ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
+
class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 69921896f46..0e114f9fedb 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,6 +76,7 @@ if TYPE_CHECKING:
PrimitiveType,
OptionalPrimitiveType,
PandasMapIterFunction,
+ ArrowMapIterFunction,
)
from pyspark.sql.connect.session import SparkSession
@@ -1572,8 +1573,11 @@ class DataFrame:
def storageLevel(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("storageLevel() is not implemented.")
- def mapInPandas(
- self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+ def _map_partitions(
+ self,
+ func: "PandasMapIterFunction",
+ schema: Union[StructType, str],
+ evalType: int,
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
@@ -1581,7 +1585,9 @@ class DataFrame:
raise Exception("Cannot mapInPandas when self._plan is empty.")
udf_obj = UserDefinedFunction(
- func, returnType=schema, evalType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+ func,
+ returnType=schema,
+ evalType=evalType,
)
return DataFrame.withPlan(
@@ -1589,10 +1595,19 @@ class DataFrame:
session=self._session,
)
+ def mapInPandas(
+ self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+ ) -> "DataFrame":
+ return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
+
mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
- def mapInArrow(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("mapInArrow() is not implemented.")
+ def mapInArrow(
+ self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+ ) -> "DataFrame":
+ return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
+
+ mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__
def writeStream(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("writeStream() is not implemented.")
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
new file mode 100644
index 00000000000..ed51d0d3d19
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_arrow_map import MapInArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_parity_arrow_map import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index 6166cc5dcc8..ff3d9b96b6b 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -37,28 +37,7 @@ if have_pandas:
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
-class MapInArrowTests(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- ReusedSQLTestCase.setUpClass()
-
- # Synchronize default timezone between Python and Java
- cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
- tz = "America/Los_Angeles"
- os.environ["TZ"] = tz
- time.tzset()
-
- cls.sc.environment["TZ"] = tz
- cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
- @classmethod
- def tearDownClass(cls):
- del os.environ["TZ"]
- if cls.tz_prev is not None:
- os.environ["TZ"] = cls.tz_prev
- time.tzset()
- ReusedSQLTestCase.tearDownClass()
-
+class MapInArrowTestsMixin(object):
def test_map_in_arrow(self):
def func(iterator):
for batch in iterator:
@@ -126,6 +105,29 @@ class MapInArrowTests(ReusedSQLTestCase):
self.assertEqual(sorted(actual), sorted(expected))
+class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedSQLTestCase.setUpClass()
+
+ # Synchronize default timezone between Python and Java
+ cls.tz_prev = os.environ.get("TZ", None) # save current tz if set
+ tz = "America/Los_Angeles"
+ os.environ["TZ"] = tz
+ time.tzset()
+
+ cls.sc.environment["TZ"] = tz
+ cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+ @classmethod
+ def tearDownClass(cls):
+ del os.environ["TZ"]
+ if cls.tz_prev is not None:
+ os.environ["TZ"] = cls.tz_prev
+ time.tzset()
+ ReusedSQLTestCase.tearDownClass()
+
+
if __name__ == "__main__":
from pyspark.sql.tests.test_arrow_map import * # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org