Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/28 03:36:19 UTC
[spark] branch branch-3.4 updated: [SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 61293faf10c [SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator
61293faf10c is described below
commit 61293faf10ca2b219fda57a9739f6be0c3ffabe1
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Mar 28 12:35:54 2023 +0900
[SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator
### What changes were proposed in this pull request?
Implements `DataFrame.toLocalIterator`.
The argument `prefetchPartitions` won't take effect for Spark Connect.
### Why are the changes needed?
Missing API.
### Does this PR introduce _any_ user-facing change?
`DataFrame.toLocalIterator` will be available.
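A minimal usage sketch, assuming a Spark Connect `SparkSession` is already available as `spark` (the `prefetchPartitions` flag is accepted but ignored under Spark Connect):

```python
# Minimal sketch, assuming `spark` is an existing Spark Connect SparkSession.
df = spark.range(1 << 20, numPartitions=2)

# Rows are produced lazily from the streamed Arrow batches rather than collected all at once.
it = df.toLocalIterator()  # prefetchPartitions has no effect for Spark Connect
first = next(it)

# The iterator may also be abandoned before it is fully consumed
# (see test_to_local_iterator_not_fully_consumed below).
for i, row in enumerate(df.toLocalIterator()):
    if i == 7:
        break
```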
### How was this patch tested?
Enabled the related tests.
Closes #40570 from ueshin/issues/SPARK-41876/toLocalIterator.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 31965a06c9f85abf2296971237b1f88065eb67c2)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/client.py | 103 ++++++++++++++-------
python/pyspark/sql/connect/dataframe.py | 26 +++++-
python/pyspark/sql/dataframe.py | 8 +-
.../sql/tests/connect/test_connect_basic.py | 1 -
.../sql/tests/connect/test_parity_dataframe.py | 14 +--
python/pyspark/sql/tests/test_dataframe.py | 21 +++--
6 files changed, 113 insertions(+), 60 deletions(-)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 6de78e72af8..84889e76103 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -35,6 +35,7 @@ import sys
from types import TracebackType
from typing import (
Iterable,
+ Iterator,
Optional,
Any,
Union,
@@ -625,8 +626,8 @@ class SparkConnectClient(object):
self._execute(req)
- def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
- return [
+ def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> Iterator[PlanMetrics]:
+ return (
PlanMetrics(
x.name,
x.plan_id,
@@ -634,18 +635,25 @@ class SparkConnectClient(object):
[MetricValue(k, v.value, v.metric_type) for k, v in x.execution_metrics.items()],
)
for x in metrics.metrics
- ]
+ )
def _build_observed_metrics(
self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
- ) -> List[PlanObservedMetrics]:
- return [
- PlanObservedMetrics(
- x.name,
- [v for v in x.values],
- )
- for x in metrics
- ]
+ ) -> Iterator[PlanObservedMetrics]:
+ return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in metrics)
+
+ def to_table_as_iterator(self, plan: pb2.Plan) -> Iterator[Union[StructType, "pa.Table"]]:
+ """
+ Return given plan as a PyArrow Table iterator.
+ """
+ logger.info(f"Executing plan {self._proto_to_string(plan)}")
+ req = self._execute_plan_request_with_metadata()
+ req.plan.CopyFrom(plan)
+ for response in self._execute_and_fetch_as_iterator(req):
+ if isinstance(response, StructType):
+ yield response
+ elif isinstance(response, pa.RecordBatch):
+ yield pa.Table.from_batches([response])
def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", Optional[StructType]]:
"""
@@ -900,46 +908,44 @@ class SparkConnectClient(object):
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
- def _execute_and_fetch(
+ def _execute_and_fetch_as_iterator(
self, req: pb2.ExecutePlanRequest
- ) -> Tuple[
- Optional["pa.Table"],
- Optional[StructType],
- List[PlanMetrics],
- List[PlanObservedMetrics],
- Dict[str, Any],
+ ) -> Iterator[
+ Union[
+ "pa.RecordBatch",
+ StructType,
+ PlanMetrics,
+ PlanObservedMetrics,
+ Dict[str, Any],
+ ]
]:
- logger.info("ExecuteAndFetch")
+ logger.info("ExecuteAndFetchAsIterator")
- m: Optional[pb2.ExecutePlanResponse.Metrics] = None
- om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
- batches: List[pa.RecordBatch] = []
- schema: Optional[StructType] = None
- properties = {}
try:
for attempt in Retrying(
can_retry=SparkConnectClient.retry_exception, **self._retry_policy
):
with attempt:
- batches = []
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
if b.session_id != self._session_id:
raise SparkConnectException(
"Received incorrect session identifier for request: "
f"{b.session_id} != {self._session_id}"
)
- if b.metrics is not None:
+ if b.HasField("metrics"):
logger.debug("Received metric batch.")
- m = b.metrics
- if b.observed_metrics is not None:
+ yield from self._build_metrics(b.metrics)
+ if b.observed_metrics:
logger.debug("Received observed metric batch.")
- om.extend(b.observed_metrics)
+ yield from self._build_observed_metrics(b.observed_metrics)
if b.HasField("schema"):
+ logger.debug("Received the schema.")
dt = types.proto_schema_to_pyspark_data_type(b.schema)
assert isinstance(dt, StructType)
- schema = dt
+ yield dt
if b.HasField("sql_command_result"):
- properties["sql_command_result"] = b.sql_command_result.relation
+ logger.debug("Received the SQL command result.")
+ yield {"sql_command_result": b.sql_command_result.relation}
if b.HasField("arrow_batch"):
logger.debug(
f"Received arrow batch rows={b.arrow_batch.row_count} "
@@ -949,11 +955,40 @@ class SparkConnectClient(object):
with pa.ipc.open_stream(b.arrow_batch.data) as reader:
for batch in reader:
assert isinstance(batch, pa.RecordBatch)
- batches.append(batch)
+ yield batch
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
- metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
- observed_metrics: List[PlanObservedMetrics] = self._build_observed_metrics(om)
+
+ def _execute_and_fetch(
+ self, req: pb2.ExecutePlanRequest
+ ) -> Tuple[
+ Optional["pa.Table"],
+ Optional[StructType],
+ List[PlanMetrics],
+ List[PlanObservedMetrics],
+ Dict[str, Any],
+ ]:
+ logger.info("ExecuteAndFetch")
+
+ observed_metrics: List[PlanObservedMetrics] = []
+ metrics: List[PlanMetrics] = []
+ batches: List[pa.RecordBatch] = []
+ schema: Optional[StructType] = None
+ properties: Dict[str, Any] = {}
+
+ for response in self._execute_and_fetch_as_iterator(req):
+ if isinstance(response, StructType):
+ schema = response
+ elif isinstance(response, pa.RecordBatch):
+ batches.append(response)
+ elif isinstance(response, PlanMetrics):
+ metrics.append(response)
+ elif isinstance(response, PlanObservedMetrics):
+ observed_metrics.append(response)
+ elif isinstance(response, dict):
+ properties.update(**response)
+ else:
+ raise ValueError(f"Unknown response: {response}")
if len(batches) > 0:
table = pa.Table.from_batches(batches=batches)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2dfc8e72193..1555be7778f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -21,6 +21,7 @@ check_dependencies(__name__)
from typing import (
Any,
Dict,
+ Iterator,
List,
Optional,
Tuple,
@@ -36,6 +37,7 @@ from typing import (
import sys
import random
import pandas
+import pyarrow as pa
import json
import warnings
from collections.abc import Iterable
@@ -1597,8 +1599,28 @@ class DataFrame:
def foreachPartition(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("foreachPartition() is not implemented.")
- def toLocalIterator(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("toLocalIterator() is not implemented.")
+ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
+ from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
+
+ if self._plan is None:
+ raise Exception("Cannot collect on empty plan.")
+ if self._session is None:
+ raise Exception("Cannot collect on empty session.")
+ query = self._plan.to_proto(self._session.client)
+
+ schema: Optional[StructType] = None
+ for schema_or_table in self._session.client.to_table_as_iterator(query):
+ if isinstance(schema_or_table, StructType):
+ assert schema is None
+ schema = schema_or_table
+ else:
+ assert isinstance(schema_or_table, pa.Table)
+ table = schema_or_table
+ if schema is None:
+ schema = from_arrow_schema(table.schema)
+ yield from ArrowTableToRowsConversion.convert(table, schema)
+
+ toLocalIterator.__doc__ = PySparkDataFrame.toLocalIterator.__doc__
def checkpoint(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("checkpoint() is not implemented.")
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 4bc32c41a09..cc5d264bd34 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1224,10 +1224,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
.. versionadded:: 2.0.0
+ .. versionchanged:: 3.4.0
+ Supports Spark Connect.
+
Parameters
----------
prefetchPartitions : bool, optional
- If Spark should pre-fetch the next partition before it is needed.
+ If Spark should pre-fetch the next partition before it is needed.
+
+ .. versionchanged:: 3.4.0
+ This argument does not take effect for Spark Connect.
Returns
-------
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 79c8dba537c..ef99760e99d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2832,7 +2832,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"withWatermark",
"foreach",
"foreachPartition",
- "toLocalIterator",
"checkpoint",
"localCheckpoint",
"_repr_html_",
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index ae812b4ca55..72a97a2a65c 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -67,20 +67,8 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_toDF_with_schema_string(self):
super().test_toDF_with_schema_string()
- # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_to_local_iterator(self):
- super().test_to_local_iterator()
-
- # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
- @unittest.skip("Fails in Spark Connect, should enable.")
def test_to_local_iterator_not_fully_consumed(self):
- super().test_to_local_iterator_not_fully_consumed()
-
- # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_to_local_iterator_prefetch(self):
- super().test_to_local_iterator_prefetch()
+ self.check_to_local_iterator_not_fully_consumed()
def test_to_pandas_for_array_of_struct(self):
# Spark Connect's implementation is based on Arrow.
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 0ce32ec4abf..626c282bbb2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1508,20 +1508,23 @@ class DataFrameTestsMixin:
self.assertEqual(expected, list(it))
def test_to_local_iterator_not_fully_consumed(self):
+ with QuietTest(self.sc):
+ self.check_to_local_iterator_not_fully_consumed()
+
+ def check_to_local_iterator_not_fully_consumed(self):
# SPARK-23961: toLocalIterator throws exception when not fully consumed
# Create a DataFrame large enough so that write to socket will eventually block
df = self.spark.range(1 << 20, numPartitions=2)
it = df.toLocalIterator()
self.assertEqual(df.take(1)[0], next(it))
- with QuietTest(self.sc):
- it = None # remove iterator from scope, socket is closed when cleaned up
- # Make sure normal df operations still work
- result = []
- for i, row in enumerate(df.toLocalIterator()):
- result.append(row)
- if i == 7:
- break
- self.assertEqual(df.take(8), result)
+ it = None # remove iterator from scope, socket is closed when cleaned up
+ # Make sure normal df operations still work
+ result = []
+ for i, row in enumerate(df.toLocalIterator()):
+ result.append(row)
+ if i == 7:
+ break
+ self.assertEqual(df.take(8), result)
def test_same_semantics_error(self):
with QuietTest(self.sc):