Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/28 03:36:07 UTC

[spark] branch master updated: [SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 31965a06c9f [SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator
31965a06c9f is described below

commit 31965a06c9f85abf2296971237b1f88065eb67c2
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Mar 28 12:35:54 2023 +0900

    [SPARK-41876][CONNECT][PYTHON] Implement DataFrame.toLocalIterator
    
    ### What changes were proposed in this pull request?
    
    Implements `DataFrame.toLocalIterator`.
    
    The `prefetchPartitions` argument is accepted but has no effect for Spark Connect.
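
    For illustration, a minimal usage sketch, assuming an existing Spark Connect session named `spark`:

    ```python
    # Minimal usage sketch; `spark` is assumed to be an existing Spark Connect session.
    df = spark.range(10)

    # `prefetchPartitions` is accepted for API compatibility, but has no effect
    # when running against Spark Connect.
    for row in df.toLocalIterator(prefetchPartitions=True):
        print(row.id)
    ```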
    
    ### Why are the changes needed?
    
    `DataFrame.toLocalIterator` is currently missing from the Spark Connect Python client.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `DataFrame.toLocalIterator` will be available in the Spark Connect Python client.
    
    ### How was this patch tested?
    
    Enabled the related tests in `test_parity_dataframe.py`, refactoring `test_dataframe.py` so the parity suite can reuse `check_to_local_iterator_not_fully_consumed`.
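
    The re-enabled `test_to_local_iterator_not_fully_consumed` exercises partial consumption; a minimal sketch of that pattern, again assuming an existing `spark` session:

    ```python
    # Partial-consumption pattern covered by the re-enabled test; `spark` is
    # assumed to be an existing session.
    df = spark.range(1 << 20, numPartitions=2)
    it = df.toLocalIterator()
    assert df.take(1)[0] == next(it)

    # Drop the iterator without fully consuming it; later operations still work.
    it = None
    rows = []
    for i, row in enumerate(df.toLocalIterator()):
        rows.append(row)
        if i == 7:
            break
    assert df.take(8) == rows
    ```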
    
    Closes #40570 from ueshin/issues/SPARK-41876/toLocalIterator.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/client.py               | 103 ++++++++++++++-------
 python/pyspark/sql/connect/dataframe.py            |  26 +++++-
 python/pyspark/sql/dataframe.py                    |   8 +-
 .../sql/tests/connect/test_connect_basic.py        |   1 -
 .../sql/tests/connect/test_parity_dataframe.py     |  14 +--
 python/pyspark/sql/tests/test_dataframe.py         |  21 +++--
 6 files changed, 113 insertions(+), 60 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 6de78e72af8..84889e76103 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -35,6 +35,7 @@ import sys
 from types import TracebackType
 from typing import (
     Iterable,
+    Iterator,
     Optional,
     Any,
     Union,
@@ -625,8 +626,8 @@ class SparkConnectClient(object):
 
         self._execute(req)
 
-    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
-        return [
+    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> Iterator[PlanMetrics]:
+        return (
             PlanMetrics(
                 x.name,
                 x.plan_id,
@@ -634,18 +635,25 @@ class SparkConnectClient(object):
                 [MetricValue(k, v.value, v.metric_type) for k, v in x.execution_metrics.items()],
             )
             for x in metrics.metrics
-        ]
+        )
 
     def _build_observed_metrics(
         self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"]
-    ) -> List[PlanObservedMetrics]:
-        return [
-            PlanObservedMetrics(
-                x.name,
-                [v for v in x.values],
-            )
-            for x in metrics
-        ]
+    ) -> Iterator[PlanObservedMetrics]:
+        return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in metrics)
+
+    def to_table_as_iterator(self, plan: pb2.Plan) -> Iterator[Union[StructType, "pa.Table"]]:
+        """
+        Return given plan as a PyArrow Table iterator.
+        """
+        logger.info(f"Executing plan {self._proto_to_string(plan)}")
+        req = self._execute_plan_request_with_metadata()
+        req.plan.CopyFrom(plan)
+        for response in self._execute_and_fetch_as_iterator(req):
+            if isinstance(response, StructType):
+                yield response
+            elif isinstance(response, pa.RecordBatch):
+                yield pa.Table.from_batches([response])
 
     def to_table(self, plan: pb2.Plan) -> Tuple["pa.Table", Optional[StructType]]:
         """
@@ -900,46 +908,44 @@ class SparkConnectClient(object):
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
 
-    def _execute_and_fetch(
+    def _execute_and_fetch_as_iterator(
         self, req: pb2.ExecutePlanRequest
-    ) -> Tuple[
-        Optional["pa.Table"],
-        Optional[StructType],
-        List[PlanMetrics],
-        List[PlanObservedMetrics],
-        Dict[str, Any],
+    ) -> Iterator[
+        Union[
+            "pa.RecordBatch",
+            StructType,
+            PlanMetrics,
+            PlanObservedMetrics,
+            Dict[str, Any],
+        ]
     ]:
-        logger.info("ExecuteAndFetch")
+        logger.info("ExecuteAndFetchAsIterator")
 
-        m: Optional[pb2.ExecutePlanResponse.Metrics] = None
-        om: List[pb2.ExecutePlanResponse.ObservedMetrics] = []
-        batches: List[pa.RecordBatch] = []
-        schema: Optional[StructType] = None
-        properties = {}
         try:
             for attempt in Retrying(
                 can_retry=SparkConnectClient.retry_exception, **self._retry_policy
             ):
                 with attempt:
-                    batches = []
                     for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
                         if b.session_id != self._session_id:
                             raise SparkConnectException(
                                 "Received incorrect session identifier for request: "
                                 f"{b.session_id} != {self._session_id}"
                             )
-                        if b.metrics is not None:
+                        if b.HasField("metrics"):
                             logger.debug("Received metric batch.")
-                            m = b.metrics
-                        if b.observed_metrics is not None:
+                            yield from self._build_metrics(b.metrics)
+                        if b.observed_metrics:
                             logger.debug("Received observed metric batch.")
-                            om.extend(b.observed_metrics)
+                            yield from self._build_observed_metrics(b.observed_metrics)
                         if b.HasField("schema"):
+                            logger.debug("Received the schema.")
                             dt = types.proto_schema_to_pyspark_data_type(b.schema)
                             assert isinstance(dt, StructType)
-                            schema = dt
+                            yield dt
                         if b.HasField("sql_command_result"):
-                            properties["sql_command_result"] = b.sql_command_result.relation
+                            logger.debug("Received the SQL command result.")
+                            yield {"sql_command_result": b.sql_command_result.relation}
                         if b.HasField("arrow_batch"):
                             logger.debug(
                                 f"Received arrow batch rows={b.arrow_batch.row_count} "
@@ -949,11 +955,40 @@ class SparkConnectClient(object):
                             with pa.ipc.open_stream(b.arrow_batch.data) as reader:
                                 for batch in reader:
                                     assert isinstance(batch, pa.RecordBatch)
-                                    batches.append(batch)
+                                    yield batch
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-        metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
-        observed_metrics: List[PlanObservedMetrics] = self._build_observed_metrics(om)
+
+    def _execute_and_fetch(
+        self, req: pb2.ExecutePlanRequest
+    ) -> Tuple[
+        Optional["pa.Table"],
+        Optional[StructType],
+        List[PlanMetrics],
+        List[PlanObservedMetrics],
+        Dict[str, Any],
+    ]:
+        logger.info("ExecuteAndFetch")
+
+        observed_metrics: List[PlanObservedMetrics] = []
+        metrics: List[PlanMetrics] = []
+        batches: List[pa.RecordBatch] = []
+        schema: Optional[StructType] = None
+        properties: Dict[str, Any] = {}
+
+        for response in self._execute_and_fetch_as_iterator(req):
+            if isinstance(response, StructType):
+                schema = response
+            elif isinstance(response, pa.RecordBatch):
+                batches.append(response)
+            elif isinstance(response, PlanMetrics):
+                metrics.append(response)
+            elif isinstance(response, PlanObservedMetrics):
+                observed_metrics.append(response)
+            elif isinstance(response, dict):
+                properties.update(**response)
+            else:
+                raise ValueError(f"Unknown response: {response}")
 
         if len(batches) > 0:
             table = pa.Table.from_batches(batches=batches)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 10426c3c28d..65f270f21d4 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -21,6 +21,7 @@ check_dependencies(__name__)
 from typing import (
     Any,
     Dict,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -36,6 +37,7 @@ from typing import (
 import sys
 import random
 import pandas
+import pyarrow as pa
 import json
 import warnings
 from collections.abc import Iterable
@@ -1597,8 +1599,28 @@ class DataFrame:
     def foreachPartition(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("foreachPartition() is not implemented.")
 
-    def toLocalIterator(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("toLocalIterator() is not implemented.")
+    def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
+        from pyspark.sql.connect.conversion import ArrowTableToRowsConversion
+
+        if self._plan is None:
+            raise Exception("Cannot collect on empty plan.")
+        if self._session is None:
+            raise Exception("Cannot collect on empty session.")
+        query = self._plan.to_proto(self._session.client)
+
+        schema: Optional[StructType] = None
+        for schema_or_table in self._session.client.to_table_as_iterator(query):
+            if isinstance(schema_or_table, StructType):
+                assert schema is None
+                schema = schema_or_table
+            else:
+                assert isinstance(schema_or_table, pa.Table)
+                table = schema_or_table
+                if schema is None:
+                    schema = from_arrow_schema(table.schema)
+                yield from ArrowTableToRowsConversion.convert(table, schema)
+
+    toLocalIterator.__doc__ = PySparkDataFrame.toLocalIterator.__doc__
 
     def checkpoint(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("checkpoint() is not implemented.")
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 4bc32c41a09..cc5d264bd34 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1224,10 +1224,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Supports Spark Connect.
+
         Parameters
         ----------
         prefetchPartitions : bool, optional
-            If Spark should pre-fetch the next partition  before it is needed.
+            If Spark should pre-fetch the next partition before it is needed.
+
+            .. versionchanged:: 3.4.0
+                This argument does not take effect for Spark Connect.
 
         Returns
         -------
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 79c8dba537c..ef99760e99d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2832,7 +2832,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "withWatermark",
             "foreach",
             "foreachPartition",
-            "toLocalIterator",
             "checkpoint",
             "localCheckpoint",
             "_repr_html_",
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index ae812b4ca55..72a97a2a65c 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -67,20 +67,8 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_toDF_with_schema_string(self):
         super().test_toDF_with_schema_string()
 
-    # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_to_local_iterator(self):
-        super().test_to_local_iterator()
-
-    # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_local_iterator_not_fully_consumed(self):
-        super().test_to_local_iterator_not_fully_consumed()
-
-    # TODO(SPARK-41876): Implement DataFrame `toLocalIterator`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_to_local_iterator_prefetch(self):
-        super().test_to_local_iterator_prefetch()
+        self.check_to_local_iterator_not_fully_consumed()
 
     def test_to_pandas_for_array_of_struct(self):
         # Spark Connect's implementation is based on Arrow.
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 0ce32ec4abf..626c282bbb2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1508,20 +1508,23 @@ class DataFrameTestsMixin:
         self.assertEqual(expected, list(it))
 
     def test_to_local_iterator_not_fully_consumed(self):
+        with QuietTest(self.sc):
+            self.check_to_local_iterator_not_fully_consumed()
+
+    def check_to_local_iterator_not_fully_consumed(self):
         # SPARK-23961: toLocalIterator throws exception when not fully consumed
         # Create a DataFrame large enough so that write to socket will eventually block
         df = self.spark.range(1 << 20, numPartitions=2)
         it = df.toLocalIterator()
         self.assertEqual(df.take(1)[0], next(it))
-        with QuietTest(self.sc):
-            it = None  # remove iterator from scope, socket is closed when cleaned up
-            # Make sure normal df operations still work
-            result = []
-            for i, row in enumerate(df.toLocalIterator()):
-                result.append(row)
-                if i == 7:
-                    break
-            self.assertEqual(df.take(8), result)
+        it = None  # remove iterator from scope, socket is closed when cleaned up
+        # Make sure normal df operations still work
+        result = []
+        for i, row in enumerate(df.toLocalIterator()):
+            result.append(row)
+            if i == 7:
+                break
+        self.assertEqual(df.take(8), result)
 
     def test_same_semantics_error(self):
         with QuietTest(self.sc):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org