Posted to commits@spark.apache.org by do...@apache.org on 2023/02/07 08:31:52 UTC

[spark] branch branch-3.4 updated: [SPARK-40532][CONNECT] Add Python Version into Python UDF message

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 9e0e4182f62 [SPARK-40532][CONNECT] Add Python Version into Python UDF message
9e0e4182f62 is described below

commit 9e0e4182f62e3dd30aa9d6f57a746a258830e351
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Tue Feb 7 00:31:16 2023 -0800

    [SPARK-40532][CONNECT] Add Python Version into Python UDF message
    
    ### What changes were proposed in this pull request?
    
    This PR adds the Python version reported by the remote client to the PythonUDF message, replacing the value hardcoded on the server.
    
    See also:
    
    https://github.com/apache/spark/blob/56c7cf33929d7d42b7d299c0bb7e895963241214/python/pyspark/context.py#L312
    
    https://github.com/apache/spark/blob/603dc5098217d9580f611873165d25392f41cdfe/python/pyspark/worker.py#L682-L691
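
    For context: the client now reports "%d.%d" % sys.version_info[:2] (see the
    udf.py change below), and the worker.py code linked above rejects a mismatch
    when the UDF is executed. A minimal sketch of that check, paraphrased rather
    than the exact Spark code:

        import sys

        def check_python_version(client_ver: str) -> None:
            # Compare only major.minor, matching what the client sends.
            worker_ver = "%d.%d" % sys.version_info[:2]
            if client_ver != worker_ver:
                raise RuntimeError(
                    "Python in worker has different minor version %s "
                    "than that in driver %s" % (worker_ver, client_ver))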
    
    ### Why are the changes needed?
    
    To make sure the UDF runs under the same Python version as the executors.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No to end users.
    Yes for developers: the Connect server no longer assumes the client runs Python 3.9, so other Python versions can now be used with Spark Connect.
    
    ### How was this patch tested?
    
    Manually tested.
    
    Closes #39914 from HyukjinKwon/SPARK-40532.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit e49466a8be617da86b614d0fab49f00b0da0d952)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../src/main/protobuf/spark/connect/expressions.proto       |  2 ++
 .../spark/sql/connect/planner/SparkConnectPlanner.scala     |  2 +-
 .../sql/connect/messages/ConnectProtoMessagesSuite.scala    |  2 ++
 python/pyspark/sql/connect/expressions.py                   |  5 ++++-
 python/pyspark/sql/connect/proto/expressions_pb2.py         | 10 +++++-----
 python/pyspark/sql/connect/proto/expressions_pb2.pyi        | 13 ++++++++++++-
 python/pyspark/sql/connect/udf.py                           |  2 ++
 7 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 66361883321..8682e1ee27b 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -318,6 +318,8 @@ message PythonUDF {
   int32 eval_type = 2;
   // (Required) The encoded commands of the Python UDF
   bytes command = 3;
+  // (Required) Python version being used in the client.
+  string python_ver = 4;
 }
 
 message ScalarScalaUDF {
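
Note: the new field is tag number 4 in PythonUDF. With the regenerated Python
bindings from this commit on the path, the wiring can be sanity-checked like
this (a minimal sketch, not part of the commit itself):

    from pyspark.sql.connect.proto import expressions_pb2 as proto

    field = proto.PythonUDF.DESCRIPTOR.fields_by_name["python_ver"]
    assert field.number == 4  # matches `string python_ver = 4;` above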
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 07a5e5bc156..c8a0860b871 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -894,7 +894,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       // No imported Python libraries
       pythonIncludes = Lists.newArrayList(),
       pythonExec = pythonExec,
-      pythonVer = "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+      pythonVer = fun.getPythonVer,
       // Empty broadcast variables
       broadcastVars = Lists.newArrayList(),
       // Null accumulator
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 240f6573c7d..09462ce18c2 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -63,6 +63,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
       .setEvalType(100)
       .setOutputType("\"integer\"")
       .setCommand(ByteString.copyFrom("command".getBytes()))
+      .setPythonVer("3.10")
       .build()
 
     val commonInlineUserDefinedFunctionExpr = proto.Expression
@@ -81,5 +82,6 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
     assert(fun.getDeterministic == true)
     assert(fun.getArgumentsCount == 1)
     assert(fun.hasPythonUdf == true)
+    assert(pythonUdf.getPythonVer == "3.10")
   }
 }
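
The same round trip can be checked from the Python bindings; a minimal sketch
mirroring the Scala test above (assuming the regenerated expressions_pb2
module from this commit):

    from pyspark.sql.connect.proto import expressions_pb2 as proto

    udf = proto.PythonUDF(
        output_type='"integer"',
        eval_type=100,
        command=b"command",
        python_ver="3.10",
    )
    assert udf.python_ver == "3.10"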
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 300a1a006d5..dcd7c5ebba6 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -491,22 +491,25 @@ class PythonUDF:
         output_type: str,
         eval_type: int,
         command: bytes,
+        python_ver: str,
     ) -> None:
         self._output_type = output_type
         self._eval_type = eval_type
         self._command = command
+        self._python_ver = python_ver
 
     def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
         expr = proto.PythonUDF()
         expr.output_type = self._output_type
         expr.eval_type = self._eval_type
         expr.command = self._command
+        expr.python_ver = self._python_ver
         return expr
 
     def __repr__(self) -> str:
         return (
             f"{self._output_type}, {self._eval_type}, "
-            f"{self._command}"  # type: ignore[str-bytes-safe]
+            f"{self._command}, f{self._python_ver}"  # type: ignore[str-bytes-safe]
         )
 
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 3a06e80c21e..92d9e6a610a 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -345,8 +345,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
-    _PYTHONUDF._serialized_start = 5175
-    _PYTHONUDF._serialized_end = 5274
-    _SCALARSCALAUDF._serialized_start = 5277
-    _SCALARSCALAUDF._serialized_end = 5461
+    _PYTHONUDF._serialized_start = 5176
+    _PYTHONUDF._serialized_end = 5306
+    _SCALARSCALAUDF._serialized_start = 5309
+    _SCALARSCALAUDF._serialized_end = 5493
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 604672a9ad7..934e0016c90 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1163,23 +1163,34 @@ class PythonUDF(google.protobuf.message.Message):
     OUTPUT_TYPE_FIELD_NUMBER: builtins.int
     EVAL_TYPE_FIELD_NUMBER: builtins.int
     COMMAND_FIELD_NUMBER: builtins.int
+    PYTHON_VER_FIELD_NUMBER: builtins.int
     output_type: builtins.str
     """(Required) Output type of the Python UDF"""
     eval_type: builtins.int
     """(Required) EvalType of the Python UDF"""
     command: builtins.bytes
     """(Required) The encoded commands of the Python UDF"""
+    python_ver: builtins.str
+    """(Required) Python version being used in the client."""
     def __init__(
         self,
         *,
         output_type: builtins.str = ...,
         eval_type: builtins.int = ...,
         command: builtins.bytes = ...,
+        python_ver: builtins.str = ...,
     ) -> None: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "command", b"command", "eval_type", b"eval_type", "output_type", b"output_type"
+            "command",
+            b"command",
+            "eval_type",
+            b"eval_type",
+            "output_type",
+            b"output_type",
+            "python_ver",
+            b"python_ver",
         ],
     ) -> None: ...
 
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 46d45d8bc70..6571cf76929 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -21,6 +21,7 @@ from pyspark.sql.connect import check_dependencies
 
 check_dependencies(__name__, __file__)
 
+import sys
 import functools
 from typing import Callable, Any, TYPE_CHECKING, Optional
 
@@ -131,6 +132,7 @@ class UserDefinedFunction:
             output_type=data_type_str,
             eval_type=self.evalType,
             command=CloudPickleSerializer().dumps((self.func, self._returnType)),
+            python_ver="%d.%d" % sys.version_info[:2],
         )
         return Column(
             CommonInlineUserDefinedFunction(
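
Note that sys.version_info[:2] truncates to major.minor, so only minor-version
mismatches (say, a 3.9 client against 3.10 executors) are flagged; patch levels
are not compared. A quick illustration:

    import sys

    python_ver = "%d.%d" % sys.version_info[:2]
    print(python_ver)  # e.g. "3.10" on any CPython 3.10.x interpreter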

