Posted to commits@spark.apache.org by ru...@apache.org on 2022/11/25 03:58:11 UTC
[spark] branch master updated: [SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b84ddd5e71a [SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`
b84ddd5e71a is described below
commit b84ddd5e71acc9ae3facdf47148becef3861d11d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Nov 25 11:57:33 2022 +0800
[SPARK-41216][CONNECT][PYTHON] Implement `DataFrame.{isLocal, isStreaming, printSchema, inputFiles}`
### What changes were proposed in this pull request?
~~1, Make `AnalyzePlan` support specifying multiple analysis tasks, that is, we can get `isLocal`, `schema`, `semanticHash` together in a single RPC if we want.~~
2, Implement the following APIs (a usage sketch follows the list):
- isLocal
- isStreaming
- printSchema
- ~~semanticHash~~
- ~~sameSemantics~~
- inputFiles
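A minimal usage sketch of the new APIs (not part of this commit); it assumes `spark` is a Spark Connect `RemoteSparkSession` and that a table named `my_table` exists:

```python
# Hypothetical usage sketch; `spark` and "my_table" are assumptions,
# not part of this commit.
df = spark.read.table("my_table")

df.printSchema()         # prints the schema as a tree, e.g. "root\n |-- ..."
print(df.isLocal)        # False: a table scan needs executors
print(df.isStreaming)    # False: batch source
print(df.inputFiles())   # best-effort list of backing file paths
```

Note that `isLocal` and `isStreaming` are properties here (no parentheses), while `inputFiles()` is a method; each call issues its own `AnalyzePlan` RPC.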
### Why are the changes needed?
For API coverage.
### Does this PR introduce _any_ user-facing change?
Yes, new APIs.
### How was this patch tested?
Added UTs.
Closes #38742 from zhengruifeng/connect_df_print_schema.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../src/main/protobuf/spark/connect/base.proto | 13 ++++
.../sql/connect/service/SparkConnectService.scala | 11 ++-
.../connect/planner/SparkConnectServiceSuite.scala | 9 +++
python/pyspark/sql/connect/client.py | 23 ++++++-
python/pyspark/sql/connect/dataframe.py | 78 +++++++++++++++++++++-
python/pyspark/sql/connect/proto/base_pb2.py | 36 +++++-----
python/pyspark/sql/connect/proto/base_pb2.pyi | 36 +++++++++-
.../sql/tests/connect/test_connect_basic.py | 40 +++++++++++
8 files changed, 221 insertions(+), 25 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index d6dac4854ef..5f9a4411ecd 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -112,6 +112,19 @@ message AnalyzePlanResponse {
// The extended explain string as produced by Spark.
string explain_string = 3;
+
+ // Get the tree string of the schema.
+ string tree_string = 4;
+
+ // Whether the 'collect' and 'take' methods can be run locally.
+ bool is_local = 5;
+
+ // Whether this plan contains one or more sources that continuously
+ // return data as it arrives.
+ bool is_streaming = 6;
+
+ // A best-effort snapshot of the files that compose this Dataset
+ repeated string input_files = 7;
}
// A request to be executed by the service.
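Once these proto fields are compiled (the regenerated `base_pb2.py` appears further down in this commit), they surface on the generated message class. A minimal sketch, assuming the regenerated bindings are importable; all field values here are made up:

```python
# Sketch only: exercises the four new AnalyzePlanResponse fields via the
# regenerated Python bindings.
from pyspark.sql.connect.proto import base_pb2

resp = base_pb2.AnalyzePlanResponse(
    tree_string="root\n |-- id: long (nullable = true)\n",
    is_local=False,
    is_streaming=False,
    input_files=["/tmp/data/part-00000.txt"],  # hypothetical path
)
print(resp.tree_string)
print(list(resp.input_files))
```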
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 0c7a2ad2690..3046c8eebfc 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connect.service
import java.util.concurrent.TimeUnit
+import scala.collection.JavaConverters._
+
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
import io.grpc.{Server, Status}
@@ -127,10 +129,13 @@ class SparkConnectService(debug: Boolean)
val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(explainMode)
- val response = proto.AnalyzePlanResponse
- .newBuilder()
- .setExplainString(explainString)
+ val response = proto.AnalyzePlanResponse.newBuilder()
response.setSchema(DataTypeProtoConverter.toConnectProtoType(ds.schema))
+ response.setExplainString(explainString)
+ response.setTreeString(ds.schema.treeString)
+ response.setIsLocal(ds.isLocal)
+ response.setIsStreaming(ds.isStreaming)
+ response.addAllInputFiles(ds.inputFiles.toSeq.asJava)
}
}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 6ca3c2430c4..e5cd84fb504 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -65,6 +65,15 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(
schema.getFields(1).getName == "col2"
&& schema.getFields(1).getDataType.getKindCase == proto.DataType.KindCase.STRING)
+
+ assert(!response.getIsLocal)
+ assert(!response.getIsStreaming)
+
+ assert(response.getTreeString.contains("root"))
+ assert(response.getTreeString.contains("|-- col1: integer (nullable = true)"))
+ assert(response.getTreeString.contains("|-- col2: string (nullable = true)"))
+
+ assert(response.getInputFilesCount === 0)
}
}
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index eb2e2227fb9..b41df12c357 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -266,13 +266,32 @@ class PlanMetrics:
class AnalyzeResult:
- def __init__(self, schema: pb2.DataType, explain: str):
+ def __init__(
+ self,
+ schema: pb2.DataType,
+ explain: str,
+ tree_string: str,
+ is_local: bool,
+ is_streaming: bool,
+ input_files: List[str],
+ ):
self.schema = schema
self.explain_string = explain
+ self.tree_string = tree_string
+ self.is_local = is_local
+ self.is_streaming = is_streaming
+ self.input_files = input_files
@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
- return AnalyzeResult(pb.schema, pb.explain_string)
+ return AnalyzeResult(
+ pb.schema,
+ pb.explain_string,
+ pb.tree_string,
+ pb.is_local,
+ pb.is_streaming,
+ pb.input_files,
+ )
class RemoteSparkSession(object):
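A small round-trip sketch (not part of this commit) of how `AnalyzeResult.fromProto` picks up the new fields; it assumes the classes above are importable as shown and uses illustrative field values:

```python
# Sketch: build a response message by hand and convert it with
# AnalyzeResult.fromProto.
from pyspark.sql.connect.client import AnalyzeResult
from pyspark.sql.connect.proto import base_pb2

pb = base_pb2.AnalyzePlanResponse(
    explain_string="== Physical Plan ==\n...",
    tree_string="root\n",
    is_local=True,
    is_streaming=False,
)
result = AnalyzeResult.fromProto(pb)
assert result.is_local and not result.is_streaming
print(result.tree_string)
```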
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e3a7e8c7335..23340e46165 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -104,7 +104,6 @@ class DataFrame(object):
"""Creates a new data frame"""
self._schema = schema
self._plan: Optional[plan.LogicalPlan] = None
- self._cache: Dict[str, Any] = {}
self._session: "RemoteSparkSession" = session
def __repr__(self) -> str:
@@ -822,6 +821,83 @@ class DataFrame(object):
else:
return self._schema
+ @property
+ def isLocal(self) -> bool:
+ """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally
+ (without any Spark executors).
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ bool
+ """
+ if self._plan is None:
+ raise Exception("Cannot analyze on empty plan.")
+ query = self._plan.to_proto(self._session)
+ return self._session._analyze(query).is_local
+
+ @property
+ def isStreaming(self) -> bool:
+ """Returns ``True`` if this :class:`DataFrame` contains one or more sources that
+ continuously return data as it arrives. A :class:`DataFrame` that reads data from a
+ streaming source must be executed as a :class:`StreamingQuery` using the :func:`start`
+ method in :class:`DataStreamWriter`. Methods that return a single answer (e.g.,
+ :func:`count` or :func:`collect`) will throw an :class:`AnalysisException` when there
+ is a streaming source present.
+
+ .. versionadded:: 3.4.0
+
+ Notes
+ -----
+ This API is evolving.
+
+ Returns
+ -------
+ bool
+ Whether it's a streaming DataFrame or not.
+ """
+ if self._plan is None:
+ raise Exception("Cannot analyze on empty plan.")
+ query = self._plan.to_proto(self._session)
+ return self._session._analyze(query).is_streaming
+
+ def _tree_string(self) -> str:
+ if self._plan is None:
+ raise Exception("Cannot analyze on empty plan.")
+ query = self._plan.to_proto(self._session)
+ return self._session._analyze(query).tree_string
+
+ def printSchema(self) -> None:
+ """Prints out the schema in the tree format.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ None
+ """
+ print(self._tree_string())
+
+ def inputFiles(self) -> List[str]:
+ """
+ Returns a best-effort snapshot of the files that compose this :class:`DataFrame`.
+ This method simply asks each constituent BaseRelation for its respective files and
+ takes the union of all results. Depending on the source relations, this may not find
+ all input files. Duplicates are removed.
+
+ .. versionadded:: 3.4.0
+
+ Returns
+ -------
+ list
+ List of file paths.
+ """
+ if self._plan is None:
+ raise Exception("Cannot analyze on empty plan.")
+ query = self._plan.to_proto(self._session)
+ return self._session._analyze(query).input_files
+
def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
"""Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index daa1c25cc8f..0d86ce8cd68 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
)
@@ -204,21 +204,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_ANALYZEPLANREQUEST._serialized_start = 585
_ANALYZEPLANREQUEST._serialized_end = 842
_ANALYZEPLANRESPONSE._serialized_start = 845
- _ANALYZEPLANRESPONSE._serialized_end = 983
- _EXECUTEPLANREQUEST._serialized_start = 986
- _EXECUTEPLANREQUEST._serialized_end = 1193
- _EXECUTEPLANRESPONSE._serialized_start = 1196
- _EXECUTEPLANRESPONSE._serialized_end = 1979
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398
- _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459
- _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462
- _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766
- _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891
- _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979
- _SPARKCONNECTSERVICE._serialized_start = 1982
- _SPARKCONNECTSERVICE._serialized_end = 2181
+ _ANALYZEPLANRESPONSE._serialized_end = 1111
+ _EXECUTEPLANREQUEST._serialized_start = 1114
+ _EXECUTEPLANREQUEST._serialized_end = 1321
+ _EXECUTEPLANRESPONSE._serialized_start = 1324
+ _EXECUTEPLANRESPONSE._serialized_end = 2107
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1526
+ _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1587
+ _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1590
+ _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2107
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1685
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2017
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1894
+ _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2017
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2019
+ _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2107
+ _SPARKCONNECTSERVICE._serialized_start = 2110
+ _SPARKCONNECTSERVICE._serialized_end = 2309
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 64bb51d4c0b..ea82aaf21e2 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -283,17 +283,38 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
CLIENT_ID_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
EXPLAIN_STRING_FIELD_NUMBER: builtins.int
+ TREE_STRING_FIELD_NUMBER: builtins.int
+ IS_LOCAL_FIELD_NUMBER: builtins.int
+ IS_STREAMING_FIELD_NUMBER: builtins.int
+ INPUT_FILES_FIELD_NUMBER: builtins.int
client_id: builtins.str
@property
def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
explain_string: builtins.str
"""The extended explain string as produced by Spark."""
+ tree_string: builtins.str
+ """Get the tree string of the schema."""
+ is_local: builtins.bool
+ """Whether the 'collect' and 'take' methods can be run locally."""
+ is_streaming: builtins.bool
+ """Whether this plan contains one or more sources that continuously
+ return data as it arrives.
+ """
+ @property
+ def input_files(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """A best-effort snapshot of the files that compose this Dataset"""
def __init__(
self,
*,
client_id: builtins.str = ...,
schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
explain_string: builtins.str = ...,
+ tree_string: builtins.str = ...,
+ is_local: builtins.bool = ...,
+ is_streaming: builtins.bool = ...,
+ input_files: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["schema", b"schema"]
@@ -301,7 +322,20 @@ class AnalyzePlanResponse(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
- "client_id", b"client_id", "explain_string", b"explain_string", "schema", b"schema"
+ "client_id",
+ b"client_id",
+ "explain_string",
+ b"explain_string",
+ "input_files",
+ b"input_files",
+ "is_local",
+ b"is_local",
+ "is_streaming",
+ b"is_streaming",
+ "schema",
+ b"schema",
+ "tree_string",
+ b"tree_string",
],
) -> None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 529d53ed7bc..845d6ead567 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -211,6 +211,46 @@ class SparkConnectTests(SparkConnectSQLTestCase):
self.connect.sql(query).schema.__repr__(),
)
+ def test_print_schema(self):
+ # SPARK-41216: Test print schema
+ tree_str = self.connect.sql("SELECT 1 AS X, 2 AS Y")._tree_string()
+ # root
+ # |-- X: integer (nullable = false)
+ # |-- Y: integer (nullable = false)
+ expected = "root\n |-- X: integer (nullable = false)\n |-- Y: integer (nullable = false)\n"
+ self.assertEqual(tree_str, expected)
+
+ def test_is_local(self):
+ # SPARK-41216: Test is local
+ self.assertTrue(self.connect.sql("SHOW DATABASES").isLocal)
+ self.assertFalse(self.connect.read.table(self.tbl_name).isLocal)
+
+ def test_is_streaming(self):
+ # SPARK-41216: Test is streaming
+ self.assertFalse(self.connect.read.table(self.tbl_name).isStreaming)
+ self.assertFalse(self.connect.sql("SELECT 1 AS X LIMIT 0").isStreaming)
+
+ def test_input_files(self):
+ # SPARK-41216: Test input files
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ try:
+ self.df_text.write.text(tmpPath)
+
+ input_files_list1 = (
+ self.spark.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
+ )
+ input_files_list2 = (
+ self.connect.read.format("text").schema("id STRING").load(path=tmpPath).inputFiles()
+ )
+
+ self.assertTrue(len(input_files_list1) > 0)
+ self.assertEqual(len(input_files_list1), len(input_files_list2))
+ for file_path in input_files_list2:
+ self.assertTrue(file_path in input_files_list1)
+ finally:
+ shutil.rmtree(tmpPath)
+
def test_simple_binary_expressions(self):
"""Test complex expression"""
df = self.connect.read.table(self.tbl_name)