You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/04 05:46:48 UTC

[spark] branch branch-3.4 updated: [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 948873420e9 [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column
948873420e9 is described below

commit 948873420e99f728e18a25890ad375cdd39afe59
Author: Jiaan Geng <be...@163.com>
AuthorDate: Sat Mar 4 13:45:56 2023 +0800

    [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column
    
    ### What changes were proposed in this pull request?
    When `colRegex` matches only a single column, the resulting column should be linked to the plan's plan_id. For reference, here is the non-connect Dataset code that does this:
    https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1512
    This also needs to be fixed for the Python client.
    
    ### Why are the changes needed?
    Let the `UnresolvedAttribute` link the plan_id if it exists.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New test cases.
    
    Closes #40265 from beliefer/SPARK-42556.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit c99a632fea74136964b27b28563115fe2d7667b3)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  21 +++++++-----
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |   5 +++
 .../main/protobuf/spark/connect/expressions.proto  |   3 ++
 .../resources/query-tests/queries/colRegex.json    |   3 +-
 .../query-tests/queries/colRegex.proto.bin         | Bin 60 -> 62 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |   6 +++-
 python/pyspark/sql/connect/dataframe.py            |   5 ++-
 python/pyspark/sql/connect/expressions.py          |  10 +++---
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  38 ++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  16 ++++++++-
 10 files changed, 72 insertions(+), 35 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1cd3c541950..e264f1c0c0c 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -917,6 +917,13 @@ class Dataset[T] private[sql] (
         .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava)
   }
 
+  private def getPlanId: Option[Long] =
+    if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
+      Option(plan.getRoot.getCommon.getPlanId)
+    } else {
+      None
+    }
+
   /**
    * Selects column based on the column name and returns it as a [[Column]].
    *
@@ -927,12 +934,7 @@ class Dataset[T] private[sql] (
    * @since 3.4.0
    */
   def col(colName: String): Column = {
-    val planId = if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) {
-      Option(plan.getRoot.getCommon.getPlanId)
-    } else {
-      None
-    }
-    Column.apply(colName, planId)
+    Column.apply(colName, getPlanId)
   }
 
   /**
@@ -940,8 +942,11 @@ class Dataset[T] private[sql] (
    * @group untypedrel
    * @since 3.4.0
    */
-  def colRegex(colName: String): Column = Column { builder =>
-    builder.getUnresolvedRegexBuilder.setColName(colName)
+  def colRegex(colName: String): Column = {
+    Column { builder =>
+      val unresolvedRegexBuilder = builder.getUnresolvedRegexBuilder.setColName(colName)
+      getPlanId.foreach(unresolvedRegexBuilder.setPlanId)
+    }
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index a3f1de55892..5c35ef448be 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -493,6 +493,11 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     val right = spark.range(100).select(col("id"), rand(12).as("a"))
     val joined = left.join(right, left("id") === right("id")).select(left("id"), right("a"))
     assert(joined.schema.catalogString === "struct<id:bigint,a:double>")
+
+    val joined2 = left
+      .join(right, left.colRegex("id") === right.colRegex("id"))
+      .select(left("id"), right("a"))
+    assert(joined2.schema.catalogString === "struct<id:bigint,a:double>")
   }
 
   test("test temp view") {
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 1929d9cdca3..e37a13ee959 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -242,6 +242,9 @@ message Expression {
   message UnresolvedRegex {
     // (Required) The column name used to extract column with regex.
     string col_name = 1;
+
+    // (Optional) The id of corresponding connect plan.
+    optional int64 plan_id = 2;
   }
 
   // Extracts a value or values from an Expression
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
index 56021594c88..3a7508b63a9 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json
@@ -13,7 +13,8 @@
     },
     "expressions": [{
       "unresolvedRegex": {
-        "colName": "`a|id`"
+        "colName": "`a|id`",
+        "planId": "1"
       }
     }]
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin
index 2f3ab10233e..ce518b35fbd 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c8b1b3125f9..76a4c7faaa2 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1179,7 +1179,11 @@ class SparkConnectPlanner(val session: SparkSession) {
       case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
         UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)
       case _ =>
-        UnresolvedAttribute.quotedString(regex.getColName)
+        val expr = UnresolvedAttribute.quotedString(regex.getColName)
+        if (regex.hasPlanId) {
+          expr.setTagValue(LogicalPlan.PLAN_ID_TAG, regex.getPlanId)
+        }
+        expr
     }
   }
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 955186787a6..471dbf89582 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -153,7 +153,10 @@ class DataFrame:
                 error_class="NOT_STR",
                 message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__},
             )
-        return Column(UnresolvedRegex(colName))
+        if self._plan is not None:
+            return Column(UnresolvedRegex(colName, self._plan._plan_id))
+        else:
+            return Column(UnresolvedRegex(colName))
 
     colRegex.__doc__ = PySparkDataFrame.colRegex.__doc__
 
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index f3c9e2c70c4..2b1901167c1 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -642,18 +642,20 @@ class UnresolvedExtractValue(Expression):
 
 
 class UnresolvedRegex(Expression):
-    def __init__(
-        self,
-        col_name: str,
-    ) -> None:
+    def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None:
         super().__init__()
 
         assert isinstance(col_name, str)
         self.col_name = col_name
 
+        assert plan_id is None or isinstance(plan_id, int)
+        self._plan_id = plan_id
+
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         expr = proto.Expression()
         expr.unresolved_regex.col_name = self.col_name
+        if self._plan_id is not None:
+            expr.unresolved_regex.plan_id = self._plan_id
         return expr
 
     def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 891be5ea9ea..6e515235c7d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe6%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 4901
+    _EXPRESSION._serialized_end = 4943
     _EXPRESSION_WINDOW._serialized_start = 1475
     _EXPRESSION_WINDOW._serialized_end = 2258
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -332,21 +332,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
     _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
     _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4354
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4541
-    _EXPRESSION_ALIAS._serialized_start = 4543
-    _EXPRESSION_ALIAS._serialized_end = 4663
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215
-    _PYTHONUDF._serialized_start = 5218
-    _PYTHONUDF._serialized_end = 5348
-    _SCALARSCALAUDF._serialized_start = 5351
-    _SCALARSCALAUDF._serialized_end = 5535
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4258
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4261
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4393
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4396
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4583
+    _EXPRESSION_ALIAS._serialized_start = 4585
+    _EXPRESSION_ALIAS._serialized_end = 4705
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4708
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4866
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4868
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4930
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4946
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5257
+    _PYTHONUDF._serialized_start = 5260
+    _PYTHONUDF._serialized_end = 5390
+    _SCALARSCALAUDF._serialized_start = 5393
+    _SCALARSCALAUDF._serialized_end = 5577
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 88b1fd8ef7e..996de7fef2d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -753,16 +753,30 @@ class Expression(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         COL_NAME_FIELD_NUMBER: builtins.int
+        PLAN_ID_FIELD_NUMBER: builtins.int
         col_name: builtins.str
         """(Required) The column name used to extract column with regex."""
+        plan_id: builtins.int
+        """(Optional) The id of corresponding connect plan."""
         def __init__(
             self,
             *,
             col_name: builtins.str = ...,
+            plan_id: builtins.int | None = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+        ) -> builtins.bool: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["col_name", b"col_name"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_plan_id", b"_plan_id", "col_name", b"col_name", "plan_id", b"plan_id"
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+        ) -> typing_extensions.Literal["plan_id"] | None: ...
 
     class UnresolvedExtractValue(google.protobuf.message.Message):
         """Extracts a value or values from an Expression"""


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org