You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/25 03:58:11 UTC

[spark] branch master updated: [SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b84ddd5e71a [SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`
b84ddd5e71a is described below

commit b84ddd5e71acc9ae3facdf47148becef3861d11d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Nov 25 11:57:33 2022 +0800

    [SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`
    
    ### What changes were proposed in this pull request?
    ~~1. Make `AnalyzePlan` support multiple specified analysis tasks, so that `isLocal`, `schema`, and `semanticHash` can be fetched together in a single RPC if desired.~~
    2. Implement the following APIs:
    
    - isLocal
    - isStreaming
    - printSchema
    - ~~semanticHash~~
    - ~~sameSemantics~~
    - inputFiles
    
    ### Why are the changes needed?
    For API coverage: these PySpark `DataFrame` APIs were missing from the Spark Connect Python client.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, new APIs.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #38742 from zhengruifeng/connect_df_print_schema.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     | 13 ++++
 .../sql/connect/service/SparkConnectService.scala  | 11 ++-
 .../connect/planner/SparkConnectServiceSuite.scala |  9 +++
 python/pyspark/sql/connect/client.py               | 23 ++++++-
 python/pyspark/sql/connect/dataframe.py            | 78 +++++++++++++++++++++-
 python/pyspark/sql/connect/proto/base_pb2.py       | 36 +++++-----
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 36 +++++++++-
 .../sql/tests/connect/test_connect_basic.py        | 40 +++++++++++
 8 files changed, 221 insertions(+), 25 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index d6dac4854ef..5f9a4411ecd 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -112,6 +112,19 @@ message AnalyzePlanResponse {
 
   // The extended explain string as produced by Spark.
   string explain_string = 3;
+
+  // Get the tree string of the schema.
+  string tree_string = 4;
+
+  // Whether the 'collect' and 'take' methods can be run locally.
+  bool is_local = 5;
+
+  // Whether this plan contains one or more sources that continuously
+  // return data as it arrives.
+  bool is_streaming = 6;
+
+  // A best-effort snapshot of the files that compose this Dataset
+  repeated string input_files = 7;
 }
 
 // A request to be executed by the service.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 0c7a2ad2690..3046c8eebfc 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connect.service
 
 import java.util.concurrent.TimeUnit
 
+import scala.collection.JavaConverters._
+
 import com.google.common.base.Ticker
 import com.google.common.cache.CacheBuilder
 import io.grpc.{Server, Status}
@@ -127,10 +129,13 @@ class SparkConnectService(debug: Boolean)
     val ds = Dataset.ofRows(session, logicalPlan)
     val explainString = ds.queryExecution.explainString(explainMode)
 
-    val response = proto.AnalyzePlanResponse
-      .newBuilder()
-      .setExplainString(explainString)
+    val response = proto.AnalyzePlanResponse.newBuilder()
     response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
+    response.setExplainString(explainString)
+    response.setTreeString(ds.schema.treeString)
+    response.setIsLocal(ds.isLocal)
+    response.setIsStreaming(ds.isStreaming)
+    response.addAllInputFiles(ds.inputFiles.toSeq.asJava)
   }
 }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 6ca3c2430c4..e5cd84fb504 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -65,6 +65,15 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       assert(
         schema.getFields(1).getName == "col2"
           && schema.getFields(1).getDataType.getKindCase == proto.DataType.KindCase.STRING)
+
+      assert(!response.getIsLocal)
+      assert(!response.getIsStreaming) // batch plan: must not report streaming sources
+
+      assert(response.getTreeString.contains("root"))
+      assert(response.getTreeString.contains("|-- col1: integer (nullable = true)"))
+      assert(response.getTreeString.contains("|-- col2: string (nullable = true)"))
+
+      assert(response.getInputFilesCount === 0)
     }
   }
 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index eb2e2227fb9..b41df12c357 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -266,13 +266,32 @@ class PlanMetrics:
 
 
 class AnalyzeResult:
-    def __init__(self, schema: pb2.DataType, explain: str):
+    def __init__(
+        self,
+        schema: pb2.DataType,
+        explain: str,
+        tree_string: str,
+        is_local: bool,
+        is_streaming: bool,
+        input_files: List[str],
+    ):
         self.schema = schema
         self.explain_string = explain
+        self.tree_string = tree_string
+        self.is_local = is_local
+        self.is_streaming = is_streaming
+        self.input_files = input_files
 
     @classmethod
     def fromProto(cls, pb: Any) -> "AnalyzeResult":
-        return AnalyzeResult(pb.schema, pb.explain_string)
+        return AnalyzeResult(
+            pb.schema,
+            pb.explain_string,
+            pb.tree_string,
+            pb.is_local,
+            pb.is_streaming,
+            list(pb.input_files),  # copy the repeated proto field into a plain List[str]
+        )
 
 
 class RemoteSparkSession(object):
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e3a7e8c7335..23340e46165 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -104,7 +104,6 @@ class DataFrame(object):
         """Creates a new data frame"""
         self._schema = schema
         self._plan: Optional[plan.LogicalPlan] = None
-        self._cache: Dict[str, Any] = {}
         self._session: "RemoteSparkSession" = session
 
     def __repr__(self) -> str:
@@ -822,6 +821,83 @@ class DataFrame(object):
         else:
             return self._schema
 
+    @property
+    def isLocal(self) -> bool:
+        """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
+        (without any Spark executors).
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        bool
+        """
+        if self._plan is None:
+            raise Exception("Cannot analyze on empty plan.")
+        query = self._plan.to_proto(self._session)
+        return self._session._analyze(query).is_local
+
+    @property
+    def isStreaming(self) -> bool:
+        """Returns ``True`` if this :class:`DataFrame` contains one or more sources that
+        continuously return data as it arrives. A :class:`DataFrame` that reads data from a
+        streaming source must be executed as a :class:`StreamingQuery` using the :func:`start`
+        method in :class:`DataStreamWriter`.  Methods that return a single answer, (e.g.,
+        :func:`count` or :func:`collect`) will throw an :class:`AnalysisException` when there
+        is a streaming source present.
+
+        .. versionadded:: 3.4.0
+
+        Notes
+        -----
+        This API is evolving.
+
+        Returns
+        -------
+        bool
+            Whether it's streaming DataFrame or not.
+        """
+        if self._plan is None:
+            raise Exception("Cannot analyze on empty plan.")
+        query = self._plan.to_proto(self._session)
+        return self._session._analyze(query).is_streaming
+
+    def _tree_string(self) -> str:
+        if self._plan is None:
+            raise Exception("Cannot analyze on empty plan.")
+        query = self._plan.to_proto(self._session)
+        return self._session._analyze(query).tree_string
+
+    def printSchema(self) -> None:
+        """Prints out the schema in the tree format.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        None
+        """
+        print(self._tree_string())
+
+    def inputFiles(self) -> List[str]:
+        """
+        Returns a best-effort snapshot of the files that compose this :class:`DataFrame`.
+        This method simply asks each constituent BaseRelation for its respective files and
+        takes the union of all results. Depending on the source relations, this may not find
+        all input files. Duplicates are removed.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        list
+            List of file paths.
+        """
+        if self._plan is None:
+            raise Exception("Cannot analyze on empty plan.")
+        query = self._plan.to_proto(self._session)
+        return self._session._analyze(query).input_files
+
     def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
         """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
 
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index daa1c25cc8f..0d86ce8cd68 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
 )
 
 
@@ -204,21 +204,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _ANALYZEPLANREQUEST._serialized_start = 585
     _ANALYZEPLANREQUEST._serialized_end = 842
     _ANALYZEPLANRESPONSE._serialized_start = 845
-    _ANALYZEPLANRESPONSE._serialized_end = 983
-    _EXECUTEPLANREQUEST._serialized_start = 986
-    _EXECUTEPLANREQUEST._serialized_end = 1193
-    _EXECUTEPLANRESPONSE._serialized_start = 1196
-    _EXECUTEPLANRESPONSE._serialized_end = 1979
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979
-    _SPARKCONNECTSERVICE._serialized_start = 1982
-    _SPARKCONNECTSERVICE._serialized_end = 2181
+    _ANALYZEPLANRESPONSE._serialized_end = 1111
+    _EXECUTEPLANREQUEST._serialized_start = 1114
+    _EXECUTEPLANREQUEST._serialized_end = 1321
+    _EXECUTEPLANRESPONSE._serialized_start = 1324
+    _EXECUTEPLANRESPONSE._serialized_end = 2107
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1526
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1587
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1590
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2107
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1685
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2017
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1894
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2017
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2019
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2107
+    _SPARKCONNECTSERVICE._serialized_start = 2110
+    _SPARKCONNECTSERVICE._serialized_end = 2309
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 64bb51d4c0b..ea82aaf21e2 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -283,17 +283,38 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
     CLIENT_ID_FIELD_NUMBER: builtins.int
     SCHEMA_FIELD_NUMBER: builtins.int
     EXPLAIN_STRING_FIELD_NUMBER: builtins.int
+    TREE_STRING_FIELD_NUMBER: builtins.int
+    IS_LOCAL_FIELD_NUMBER: builtins.int
+    IS_STREAMING_FIELD_NUMBER: builtins.int
+    INPUT_FILES_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
     def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
     explain_string: builtins.str
     """The extended explain string as produced by Spark."""
+    tree_string: builtins.str
+    """Get the tree string of the schema."""
+    is_local: builtins.bool
+    """Whether the 'collect' and 'take' methods can be run locally."""
+    is_streaming: builtins.bool
+    """Whether this plan contains one or more sources that continuously
+    return data as it arrives.
+    """
+    @property
+    def input_files(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """A best-effort snapshot of the files that compose this Dataset"""
     def __init__(
         self,
         *,
         client_id: builtins.str = ...,
         schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
         explain_string: builtins.str = ...,
+        tree_string: builtins.str = ...,
+        is_local: builtins.bool = ...,
+        is_streaming: builtins.bool = ...,
+        input_files: collections.abc.Iterable[builtins.str] | None = ...,
     ) -> None: ...
     def HasField(
         self, field_name: typing_extensions.Literal["schema", b"schema"]
@@ -301,7 +322,20 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema"
+            "client_id",
+            b"client_id",
+            "explain_string",
+            b"explain_string",
+            "input_files",
+            b"input_files",
+            "is_local",
+            b"is_local",
+            "is_streaming",
+            b"is_streaming",
+            "schema",
+            b"schema",
+            "tree_string",
+            b"tree_string",
         ],
     ) -> None: ...
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 529d53ed7bc..845d6ead567 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -211,6 +211,46 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             self.connect.sql(query).schema.__repr__(),
         )
 
+    def test_print_schema(self):
+        # SPARK-41216: Test print schema
+        tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._tree_string()
+        # root
+        #  |-- X: integer (nullable = false)
+        #  |-- Y: integer (nullable = false)
+        expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
+        self.assertEqual(tree_str, expected)
+
+    def test_is_local(self):
+        # SPARK-41216: Test is local
+        self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal)
+        self.assertFalse(self.connect.read.table(self.tbl_name).isLocal)
+
+    def test_is_streaming(self):
+        # SPARK-41216: Test is streaming
+        self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
+        self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
+
+    def test_input_files(self):
+        # SPARK-41216: Test input files (remove the fresh dir so write.text can create it)
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        try:
+            self.df_text.write.text(tmpPath)
+
+            input_files_list1 = (
+                self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
+            )
+            input_files_list2 = (
+                self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
+            )
+
+            self.assertTrue(len(input_files_list1) > 0)
+            self.assertEqual(len(input_files_list1), len(input_files_list2))
+            for file_path in input_files_list2:
+                self.assertIn(file_path, input_files_list1)
+        finally:
+            shutil.rmtree(tmpPath, ignore_errors=True)  # dir may not exist if the write failed
+
     def test_simple_binary_expressions(self):
         """Test complex expression"""
         df = self.connect.read.table(self.tbl_name)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org