You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/04 05:46:41 UTC
[spark] branch master updated: [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c99a632fea7 [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column
c99a632fea7 is described below
commit c99a632fea74136964b27b28563115fe2d7667b3
Author: Jiaan Geng <be...@163.com>
AuthorDate: Sat Mar 4 13:45:56 2023 +0800
[SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column
### What changes were proposed in this pull request?
When `colRegex` returns a single column, it should link the plan's `plan_id`. For reference, here is the non-Connect `Dataset` code that does this:
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1512
The same fix is also applied to the Python client.
### Why are the changes needed?
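Let the `UnresolvedAttribute` produced by `colRegex` link the plan's `plan_id` if it exists. Without it, a single-column `colRegex` result cannot be tied back to the DataFrame it came from, e.g. in `left.join(right, left.colRegex("id") === right.colRegex("id"))`.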
### Does this PR introduce _any_ user-facing change?
No.
New feature.
### How was this patch tested?
New test cases in `ClientE2ETestSuite`, plus updated `colRegex` query-test golden files.
Closes #40265 from beliefer/SPARK-42556.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/scala/org/apache/spark/sql/Dataset.scala | 21 +++++++-----
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 5 +++
.../main/protobuf/spark/connect/expressions.proto | 3 ++
.../resources/query-tests/queries/colRegex.json | 3 +-
.../query-tests/queries/colRegex.proto.bin | Bin 60 -> 62 bytes
.../sql/connect/planner/SparkConnectPlanner.scala | 6 +++-
python/pyspark/sql/connect/dataframe.py | 5 ++-
python/pyspark/sql/connect/expressions.py | 10 +++---
.../pyspark/sql/connect/proto/expressions_pb2.py | 38 ++++++++++-----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 16 ++++++++-
10 files changed, 72 insertions(+), 35 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1cd3c541950..e264f1c0c0c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -917,6 +917,13 @@ class Dataset[T] private[sql] (
.addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
}
+ private def getPlanId: Option[Long] =
+ if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
+ Option(plan.getRoot.getCommon.getPlanId)
+ } else {
+ None
+ }
+
/**
* Selects column based on the column name and returns it as a [[Column]].
*
@@ -927,12 +934,7 @@ class Dataset[T] private[sql] (
* @since 3.4.0
*/
def col(colName: String): Column = {
- val planId = if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
- Option(plan.getRoot.getCommon.getPlanId)
- } else {
- None
- }
- Column.apply(colName, planId)
+ Column.apply(colName, getPlanId)
}
/**
@@ -940,8 +942,11 @@ class Dataset[T] private[sql] (
* @group untypedrel
* @since 3.4.0
*/
- def colRegex(colName: String): Column = Column { builder =>
- builder.getUnresolvedRegexBuilder.setColName(colName)
+ def colRegex(colName: String): Column = {
+ Column { builder =>
+ val unresolvedRegexBuilder = builder.getUnresolvedRegexBuilder.setColName(colName)
+ getPlanId.foreach(unresolvedRegexBuilder.setPlanId)
+ }
}
/**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index a3f1de55892..5c35ef448be 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -493,6 +493,11 @@ class ClientE2ETestSuite extends RemoteSparkSession {
val right = spark.range(100).select(col("id"), rand(12).as("a"))
val joined = left.join(right, left("id") === right("id")).select(left("id"), right("a"))
assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
+
+ val joined2 = left
+ .join(right, left.colRegex("id") === right.colRegex("id"))
+ .select(left("id"), right("a"))
+ assert(joined2.schema.catalogString === "struct<id:bigint,a:double>")
}
test("test temp view") {
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 1929d9cdca3..e37a13ee959 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -242,6 +242,9 @@ message Expression {
message UnresolvedRegex {
// (Required) The column name used to extract column with regex.
string col_name = 1;
+
+ // (Optional) The id of corresponding connect plan.
+ optional int64 plan_id = 2;
}
// Extracts a value or values from an Expression
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
index 56021594c88..3a7508b63a9 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
@@ -13,7 +13,8 @@
},
"expressions": [{
"unresolvedRegex": {
- "colName": "`a|id`"
+ "colName": "`a|id`",
+ "planId": "1"
}
}]
}
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin
index 2f3ab10233e..ce518b35fbd 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c8b1b3125f9..76a4c7faaa2 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1179,7 +1179,11 @@ class SparkConnectPlanner(val session: SparkSession) {
case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)
case _ =>
- UnresolvedAttribute.quotedString(regex.getColName)
+ val expr = UnresolvedAttribute.quotedString(regex.getColName)
+ if (regex.hasPlanId) {
+ expr.setTagValue(LogicalPlan.PLAN_ID_TAG, regex.getPlanId)
+ }
+ expr
}
}
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 955186787a6..471dbf89582 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -153,7 +153,10 @@ class DataFrame:
error_class="NOT_STR",
message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__},
)
- return Column(UnresolvedRegex(colName))
+ if self._plan is not None:
+ return Column(UnresolvedRegex(colName, self._plan._plan_id))
+ else:
+ return Column(UnresolvedRegex(colName))
colRegex.__doc__ = PySparkDataFrame.colRegex.__doc__
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index f3c9e2c70c4..2b1901167c1 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -642,18 +642,20 @@ class UnresolvedExtractValue(Expression):
class UnresolvedRegex(Expression):
- def __init__(
- self,
- col_name: str,
- ) -> None:
+ def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(col_name, str)
self.col_name = col_name
+ assert plan_id is None or isinstance(plan_id, int)
+ self._plan_id = plan_id
+
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = proto.Expression()
expr.unresolved_regex.col_name = self.col_name
+ if self._plan_id is not None:
+ expr.unresolved_regex.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 891be5ea9ea..6e515235c7d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe6%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
)
@@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 4901
+ _EXPRESSION._serialized_end = 4943
_EXPRESSION_WINDOW._serialized_start = 1475
_EXPRESSION_WINDOW._serialized_end = 2258
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -332,21 +332,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
_EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
_EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4354
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4541
- _EXPRESSION_ALIAS._serialized_start = 4543
- _EXPRESSION_ALIAS._serialized_end = 4663
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215
- _PYTHONUDF._serialized_start = 5218
- _PYTHONUDF._serialized_end = 5348
- _SCALARSCALAUDF._serialized_start = 5351
- _SCALARSCALAUDF._serialized_end = 5535
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4258
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4261
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4393
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4396
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4583
+ _EXPRESSION_ALIAS._serialized_start = 4585
+ _EXPRESSION_ALIAS._serialized_end = 4705
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4708
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4866
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4868
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4930
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4946
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5257
+ _PYTHONUDF._serialized_start = 5260
+ _PYTHONUDF._serialized_end = 5390
+ _SCALARSCALAUDF._serialized_start = 5393
+ _SCALARSCALAUDF._serialized_end = 5577
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 88b1fd8ef7e..996de7fef2d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -753,16 +753,30 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
COL_NAME_FIELD_NUMBER: builtins.int
+ PLAN_ID_FIELD_NUMBER: builtins.int
col_name: builtins.str
"""(Required) The column name used to extract column with regex."""
+ plan_id: builtins.int
+ """(Optional) The id of corresponding connect plan."""
def __init__(
self,
*,
col_name: builtins.str = ...,
+ plan_id: builtins.int | None = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+ ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["col_name", b"col_name"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_plan_id", b"_plan_id", "col_name", b"col_name", "plan_id", b"plan_id"
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+ ) -> typing_extensions.Literal["plan_id"] | None: ...
class UnresolvedExtractValue(google.protobuf.message.Message):
"""Extracts a value or values from an Expression"""
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org