Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/09 11:14:28 UTC

[spark] branch master updated: [SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c0b1735c0bf [SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available
c0b1735c0bf is described below

commit c0b1735c0bfeb1ff645d146e262d7ccd036a590e
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Thu Mar 9 20:14:12 2023 +0900

    [SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available
    
    ### What changes were proposed in this pull request?
    
    Introduces `UnparsedDataType` and delays parsing of DDL strings for Python UDFs until a `SparkConnectClient` is available.
    
    `UnparsedDataType` carries the DDL string and has it parsed on the server side.
    It should only appear at the top level and must not be nested inside other data types.
    
    Also changes `createDataFrame` to use the proto `DDLParse`.
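    
    Below is a rough sketch of the new client-side flow. The names come from the files touched in this diff; treat it as illustrative rather than the exact public API.
    
    ```python
    from pyspark.sql.connect.types import UnparsedDataType
    
    # A string return type is no longer parsed eagerly on the client; it is
    # wrapped and carried as-is until a SparkConnectClient is available.
    return_type = UnparsedDataType("a: array<short>")
    print(return_type)                   # UnparsedDataType('a: array<short>')
    print(return_type.data_type_string)  # a: array<short>
    
    # Methods that need a resolved type intentionally fail, because the string
    # has not been parsed yet; parsing happens later on the server side.
    try:
        return_type.jsonValue()
    except AssertionError as e:
        print(e)  # Invalid call to jsonValue on unresolved object
    ```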
    
    ### Why are the changes needed?
    
    Currently `parse_data_type` depends on `PySparkSession`, which creates a local PySpark session, but that won't be available on the client side.
    
    When `SparkConnectClient` is available, we can use the new proto `DDLParse` to parse data types given as strings.
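    
    For example, the client can delegate DDL parsing to the server through the analyze RPC. A minimal sketch, assuming `spark` is a Spark Connect session whose `client` attribute is the `SparkConnectClient` (this mirrors what `createDataFrame` now does internally in `session.py`):
    
    ```python
    # Ask the server to parse a DDL string; AnalyzeResult converts the proto
    # DataType in the response back into a pyspark DataType.
    parsed = spark.client._analyze(method="ddl_parse", ddl_string="a INT, b STRING").parsed
    # `parsed` should be a StructType with fields `a` and `b`.
    ```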
    
    ### Does this PR introduce _any_ user-facing change?
    
    In Spark Connect, the UDF's `returnType` attribute can be a string if it was provided as a string.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #40260 from ueshin/issues/SPARK-42630/ddl_parse.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |   4 +-
 .../src/main/protobuf/spark/connect/types.proto    |   8 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  49 ++++-----
 .../messages/ConnectProtoMessagesSuite.scala       |   4 +-
 python/pyspark/sql/connect/client.py               |  63 +++---------
 python/pyspark/sql/connect/expressions.py          |  38 ++++---
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  12 +--
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  17 ++--
 python/pyspark/sql/connect/proto/types_pb2.py      | 113 ++++++++++++---------
 python/pyspark/sql/connect/proto/types_pb2.pyi     |  25 +++++
 python/pyspark/sql/connect/session.py              |   8 +-
 python/pyspark/sql/connect/types.py                |  83 +++++++++++----
 python/pyspark/sql/connect/udf.py                  |  14 +--
 .../pyspark/sql/tests/connect/test_parity_udf.py   |  29 ++----
 python/pyspark/sql/tests/test_udf.py               |   5 +
 15 files changed, 263 insertions(+), 209 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 0aee3ca13b9..9e949dab15a 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -326,7 +326,7 @@ message CommonInlineUserDefinedFunction {
 
 message PythonUDF {
   // (Required) Output type of the Python UDF
-  string output_type = 1;
+  DataType output_type = 1;
   // (Required) EvalType of the Python UDF
   int32 eval_type = 2;
   // (Required) The encoded commands of the Python UDF
@@ -351,7 +351,7 @@ message JavaUDF {
   string class_name = 1;
 
   // (Optional) Output type of the Java UDF
-  optional string output_type = 2;
+  optional DataType output_type = 2;
 
   // (Required) Indicate if the Java user-defined function is an aggregate function
   bool aggregate = 3;
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
index 03a6968af60..68833b5d220 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -64,6 +64,9 @@ message DataType {
 
     // UserDefinedType
     UDT udt = 23;
+
+    // UnparsedDataType
+    Unparsed unparsed = 24;
   }
 
   message Boolean {
@@ -183,4 +186,9 @@ message DataType {
     optional string serialized_python_class = 4;
     DataType sql_type = 5;
   }
+
+  message Unparsed {
+    // (Required) The unparsed data type string
+    string data_type_string = 1;
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cd4da39d62f..010d0236c74 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -455,7 +455,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   }
 
   private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
-    val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
+    val schema = transformDataType(rel.getSchema)
     assert(schema.isInstanceOf[StructType])
 
     Dataset
@@ -616,6 +616,14 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def transformDataType(t: proto.DataType): DataType = {
+    t.getKindCase match {
+      case proto.DataType.KindCase.UNPARSED =>
+        parseDatatypeString(t.getUnparsed.getDataTypeString)
+      case _ => DataTypeProtoConverter.toCatalystType(t)
+    }
+  }
+
   private[connect] def parseDatatypeString(sqlText: String): DataType = {
     val parser = session.sessionState.sqlParser
     try {
@@ -960,13 +968,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     PythonUDF(
       name = fun.getFunctionName,
       func = transformPythonFunction(udf),
-      dataType = DataType.parseTypeWithFallback(
-        schema = udf.getOutputType,
-        parser = DataType.fromDDL,
-        fallbackParser = DataType.fromJson) match {
-        case s: DataType => s
-        case other => throw InvalidPlanInput(s"Invalid return type $other")
-      },
+      dataType = transformDataType(udf.getOutputType),
       children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
       evalType = udf.getEvalType,
       udfDeterministic = fun.getDeterministic)
@@ -1220,9 +1222,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformCast(cast: proto.Expression.Cast): Expression = {
     cast.getCastToTypeCase match {
       case proto.Expression.Cast.CastToTypeCase.TYPE =>
-        Cast(
-          transformExpression(cast.getExpr),
-          DataTypeProtoConverter.toCatalystType(cast.getType))
+        Cast(transformExpression(cast.getExpr), transformDataType(cast.getType))
       case _ =>
         Cast(
           transformExpression(cast.getExpr),
@@ -1615,13 +1615,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     val udpf = UserDefinedPythonFunction(
       name = fun.getFunctionName,
       func = function,
-      dataType = DataType.parseTypeWithFallback(
-        schema = udf.getOutputType,
-        parser = DataType.fromDDL,
-        fallbackParser = DataType.fromJson) match {
-        case s: DataType => s
-        case other => throw InvalidPlanInput(s"Invalid return type $other")
-      },
+      dataType = transformDataType(udf.getOutputType),
       pythonEvalType = udf.getEvalType,
       udfDeterministic = fun.getDeterministic)
 
@@ -1630,16 +1624,11 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
     val udf = fun.getJavaUdf
-    val dataType =
-      if (udf.hasOutputType) {
-        DataType.parseTypeWithFallback(
-          schema = udf.getOutputType,
-          parser = DataType.fromDDL,
-          fallbackParser = DataType.fromJson) match {
-          case s: DataType => s
-          case other => throw InvalidPlanInput(s"Invalid return type $other")
-        }
-      } else null
+    val dataType = if (udf.hasOutputType) {
+      transformDataType(udf.getOutputType)
+    } else {
+      null
+    }
     if (udf.getAggregate) {
       session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName)
     } else {
@@ -1945,7 +1934,7 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformCreateExternalTable(
       getCreateExternalTable: proto.CreateExternalTable): LogicalPlan = {
     val schema = if (getCreateExternalTable.hasSchema) {
-      val struct = DataTypeProtoConverter.toCatalystType(getCreateExternalTable.getSchema)
+      val struct = transformDataType(getCreateExternalTable.getSchema)
       assert(struct.isInstanceOf[StructType])
       struct.asInstanceOf[StructType]
     } else {
@@ -1975,7 +1964,7 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   private def transformCreateTable(getCreateTable: proto.CreateTable): LogicalPlan = {
     val schema = if (getCreateTable.hasSchema) {
-      val struct = DataTypeProtoConverter.toCatalystType(getCreateTable.getSchema)
+      val struct = transformDataType(getCreateTable.getSchema)
       assert(struct.isInstanceOf[StructType])
       struct.asInstanceOf[StructType]
     } else {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 09462ce18c2..65c03a3c2e2 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -20,6 +20,8 @@ import com.google.protobuf.ByteString
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.types.IntegerType
 
 class ConnectProtoMessagesSuite extends SparkFunSuite {
   test("UserContext can deal with extensions") {
@@ -61,7 +63,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
     val pythonUdf = proto.PythonUDF
       .newBuilder()
       .setEvalType(100)
-      .setOutputType("\"integer\"")
+      .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
       .setCommand(ByteString.copyFrom("command".getBytes()))
       .setPythonVer("3.10")
       .build()
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index baa6d641422..3c91661716e 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -31,7 +31,6 @@ import random
 import time
 import urllib.parse
 import uuid
-import json
 import sys
 from types import TracebackType
 from typing import (
@@ -72,13 +71,7 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.connect.types import parse_data_type
-from pyspark.sql.types import (
-    DataType,
-    StructType,
-    StructField,
-)
-from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.types import DataType, StructType
 from pyspark.rdd import PythonEvalType
 
 
@@ -399,14 +392,14 @@ class PlanObservedMetrics:
 class AnalyzeResult:
     def __init__(
         self,
-        schema: Optional[pb2.DataType],
+        schema: Optional[DataType],
         explain_string: Optional[str],
         tree_string: Optional[str],
         is_local: Optional[bool],
         is_streaming: Optional[bool],
         input_files: Optional[List[str]],
         spark_version: Optional[str],
-        parsed: Optional[pb2.DataType],
+        parsed: Optional[DataType],
         is_same_semantics: Optional[bool],
     ):
         self.schema = schema
@@ -421,18 +414,18 @@ class AnalyzeResult:
 
     @classmethod
     def fromProto(cls, pb: Any) -> "AnalyzeResult":
-        schema: Optional[pb2.DataType] = None
+        schema: Optional[DataType] = None
         explain_string: Optional[str] = None
         tree_string: Optional[str] = None
         is_local: Optional[bool] = None
         is_streaming: Optional[bool] = None
         input_files: Optional[List[str]] = None
         spark_version: Optional[str] = None
-        parsed: Optional[pb2.DataType] = None
+        parsed: Optional[DataType] = None
         is_same_semantics: Optional[bool] = None
 
         if pb.HasField("schema"):
-            schema = pb.schema.schema
+            schema = types.proto_schema_to_pyspark_data_type(pb.schema.schema)
         elif pb.HasField("explain"):
             explain_string = pb.explain.explain_string
         elif pb.HasField("tree_string"):
@@ -446,7 +439,7 @@ class AnalyzeResult:
         elif pb.HasField("spark_version"):
             spark_version = pb.spark_version.version
         elif pb.HasField("ddl_parse"):
-            parsed = pb.ddl_parse.parsed
+            parsed = types.proto_schema_to_pyspark_data_type(pb.ddl_parse.parsed)
         elif pb.HasField("same_semantics"):
             is_same_semantics = pb.same_semantics.result
         else:
@@ -553,14 +546,11 @@ class SparkConnectClient(object):
         if name is None:
             name = f"fun_{uuid.uuid4().hex}"
 
-        # convert str return_type to DataType
-        if isinstance(return_type, str):
-            return_type = parse_data_type(return_type)
         # construct a PythonUDF
         py_udf = PythonUDF(
-            output_type=return_type.json(),
+            output_type=return_type,
             eval_type=eval_type,
-            command=CloudPickleSerializer().dumps((function, return_type)),
+            func=function,
             python_ver="%d.%d" % sys.version_info[:2],
         )
 
@@ -586,18 +576,11 @@ class SparkConnectClient(object):
         return_type: Optional["DataTypeOrString"] = None,
         aggregate: bool = False,
     ) -> None:
-        # convert str return_type to DataType
-        if isinstance(return_type, str):
-            return_type = parse_data_type(return_type)
-
         # construct a JavaUDF
         if return_type is None:
             java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate)
         else:
-            java_udf = JavaUDF(
-                class_name=javaClassName,
-                output_type=return_type.json(),
-            )
+            java_udf = JavaUDF(class_name=javaClassName, output_type=return_type)
         fun = CommonInlineUserDefinedFunction(
             function_name=name,
             function=java_udf,
@@ -660,9 +643,6 @@ class SparkConnectClient(object):
             pdf.attrs["observed_metrics"] = observed_metrics
         return pdf
 
-    def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
-        return types.proto_schema_to_pyspark_data_type(schema)
-
     def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
         """
         Helper method to generate a one line string representation of the plan.
@@ -682,26 +662,11 @@ class SparkConnectClient(object):
         Return schema for given plan.
         """
         logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
-        proto_schema = self._analyze(method="schema", plan=plan).schema
-        assert proto_schema is not None
+        schema = self._analyze(method="schema", plan=plan).schema
+        assert schema is not None
         # Server side should populate the struct field which is the schema.
-        assert proto_schema.HasField("struct")
-
-        fields = []
-        for f in proto_schema.struct.fields:
-            if f.HasField("metadata"):
-                metadata = json.loads(f.metadata)
-            else:
-                metadata = None
-            fields.append(
-                StructField(
-                    f.name,
-                    self._proto_schema_to_pyspark_schema(f.data_type),
-                    f.nullable,
-                    metadata,
-                )
-            )
-        return StructType(fields)
+        assert isinstance(schema, StructType)
+        return schema
 
     def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
         """
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 64176327c16..5c122f40373 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -22,6 +22,7 @@ from typing import (
     cast,
     TYPE_CHECKING,
     Any,
+    Callable,
     Union,
     Sequence,
     Tuple,
@@ -36,6 +37,7 @@ from threading import Lock
 
 import numpy as np
 
+from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.types import (
     _from_numpy_type,
     DateType,
@@ -66,6 +68,7 @@ from pyspark.sql.connect.types import (
     JVM_INT_MAX,
     JVM_LONG_MIN,
     JVM_LONG_MAX,
+    UnparsedDataType,
     pyspark_types_to_proto_types,
 )
 
@@ -496,29 +499,36 @@ class PythonUDF:
 
     def __init__(
         self,
-        output_type: str,
+        output_type: Union[DataType, str],
         eval_type: int,
-        command: bytes,
+        func: Callable[..., Any],
         python_ver: str,
     ) -> None:
-        self._output_type = output_type
+        self._output_type: DataType = (
+            UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
+        )
         self._eval_type = eval_type
-        self._command = command
+        self._func = func
         self._python_ver = python_ver
 
     def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
+        if isinstance(self._output_type, UnparsedDataType):
+            parsed = session._analyze(
+                method="ddl_parse", ddl_string=self._output_type.data_type_string
+            ).parsed
+            assert isinstance(parsed, DataType)
+            output_type = parsed
+        else:
+            output_type = self._output_type
         expr = proto.PythonUDF()
-        expr.output_type = self._output_type
+        expr.output_type.CopyFrom(pyspark_types_to_proto_types(output_type))
         expr.eval_type = self._eval_type
-        expr.command = self._command
+        expr.command = CloudPickleSerializer().dumps((self._func, output_type))
         expr.python_ver = self._python_ver
         return expr
 
     def __repr__(self) -> str:
-        return (
-            f"{self._output_type}, {self._eval_type}, "
-            f"{self._command}, f{self._python_ver}"  # type: ignore[str-bytes-safe]
-        )
+        return f"{self._output_type}, {self._eval_type}, {self._func}, f{self._python_ver}"
 
 
 class JavaUDF:
@@ -527,18 +537,20 @@ class JavaUDF:
     def __init__(
         self,
         class_name: str,
-        output_type: Optional[str] = None,
+        output_type: Optional[Union[DataType, str]] = None,
         aggregate: bool = False,
     ) -> None:
         self._class_name = class_name
-        self._output_type = output_type
+        self._output_type: Optional[DataType] = (
+            UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
+        )
         self._aggregate = aggregate
 
     def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
         expr = proto.JavaUDF()
         expr.class_name = self._class_name
         if self._output_type is not None:
-            expr.output_type = self._output_type
+            expr.output_type.CopyFrom(pyspark_types_to_proto_types(self._output_type))
         expr.aggregate = self._aggregate
         return expr
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 24dd1136480..1814f52f539 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
 )
 
 
@@ -371,9 +371,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
     _PYTHONUDF._serialized_start = 5507
-    _PYTHONUDF._serialized_end = 5637
-    _SCALARSCALAUDF._serialized_start = 5640
-    _SCALARSCALAUDF._serialized_end = 5824
-    _JAVAUDF._serialized_start = 5826
-    _JAVAUDF._serialized_end = 5950
+    _PYTHONUDF._serialized_end = 5662
+    _SCALARSCALAUDF._serialized_start = 5665
+    _SCALARSCALAUDF._serialized_end = 5849
+    _JAVAUDF._serialized_start = 5852
+    _JAVAUDF._serialized_end = 6001
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 19b47c7ab91..3c8de8abb4e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1242,8 +1242,9 @@ class PythonUDF(google.protobuf.message.Message):
     EVAL_TYPE_FIELD_NUMBER: builtins.int
     COMMAND_FIELD_NUMBER: builtins.int
     PYTHON_VER_FIELD_NUMBER: builtins.int
-    output_type: builtins.str
-    """(Required) Output type of the Python UDF"""
+    @property
+    def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+        """(Required) Output type of the Python UDF"""
     eval_type: builtins.int
     """(Required) EvalType of the Python UDF"""
     command: builtins.bytes
@@ -1253,11 +1254,14 @@ class PythonUDF(google.protobuf.message.Message):
     def __init__(
         self,
         *,
-        output_type: builtins.str = ...,
+        output_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
         eval_type: builtins.int = ...,
         command: builtins.bytes = ...,
         python_ver: builtins.str = ...,
     ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["output_type", b"output_type"]
+    ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
@@ -1331,15 +1335,16 @@ class JavaUDF(google.protobuf.message.Message):
     AGGREGATE_FIELD_NUMBER: builtins.int
     class_name: builtins.str
     """(Required) Fully qualified name of Java class"""
-    output_type: builtins.str
-    """(Optional) Output type of the Java UDF"""
+    @property
+    def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+        """(Optional) Output type of the Java UDF"""
     aggregate: builtins.bool
     """(Required) Indicate if the Java user-defined function is an aggregate function"""
     def __init__(
         self,
         *,
         class_name: builtins.str = ...,
-        output_type: builtins.str | None = ...,
+        output_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
         aggregate: builtins.bool = ...,
     ) -> None: ...
     def HasField(
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index cba709c19c9..eec58d5cee6 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,7 +30,7 @@ _sym_db = _symbol_database.Default()
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xd1\x1f\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01( [...]
+    b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xc7 \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...]
 )
 
 
@@ -59,6 +59,7 @@ _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
 _DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
 _DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
 _DATATYPE_UDT = _DATATYPE.nested_types_by_name["UDT"]
+_DATATYPE_UNPARSED = _DATATYPE.nested_types_by_name["Unparsed"]
 DataType = _reflection.GeneratedProtocolMessageType(
     "DataType",
     (_message.Message,),
@@ -279,6 +280,15 @@ DataType = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.DataType.UDT)
             },
         ),
+        "Unparsed": _reflection.GeneratedProtocolMessageType(
+            "Unparsed",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _DATATYPE_UNPARSED,
+                "__module__": "spark.connect.types_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.DataType.Unparsed)
+            },
+        ),
         "DESCRIPTOR": _DATATYPE,
         "__module__": "spark.connect.types_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.DataType)
@@ -309,59 +319,62 @@ _sym_db.RegisterMessage(DataType.Struct)
 _sym_db.RegisterMessage(DataType.Array)
 _sym_db.RegisterMessage(DataType.Map)
 _sym_db.RegisterMessage(DataType.UDT)
+_sym_db.RegisterMessage(DataType.Unparsed)
 
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _DATATYPE._serialized_start = 45
-    _DATATYPE._serialized_end = 4094
-    _DATATYPE_BOOLEAN._serialized_start = 1470
-    _DATATYPE_BOOLEAN._serialized_end = 1537
-    _DATATYPE_BYTE._serialized_start = 1539
-    _DATATYPE_BYTE._serialized_end = 1603
-    _DATATYPE_SHORT._serialized_start = 1605
-    _DATATYPE_SHORT._serialized_end = 1670
-    _DATATYPE_INTEGER._serialized_start = 1672
-    _DATATYPE_INTEGER._serialized_end = 1739
-    _DATATYPE_LONG._serialized_start = 1741
-    _DATATYPE_LONG._serialized_end = 1805
-    _DATATYPE_FLOAT._serialized_start = 1807
-    _DATATYPE_FLOAT._serialized_end = 1872
-    _DATATYPE_DOUBLE._serialized_start = 1874
-    _DATATYPE_DOUBLE._serialized_end = 1940
-    _DATATYPE_STRING._serialized_start = 1942
-    _DATATYPE_STRING._serialized_end = 2008
-    _DATATYPE_BINARY._serialized_start = 2010
-    _DATATYPE_BINARY._serialized_end = 2076
-    _DATATYPE_NULL._serialized_start = 2078
-    _DATATYPE_NULL._serialized_end = 2142
-    _DATATYPE_TIMESTAMP._serialized_start = 2144
-    _DATATYPE_TIMESTAMP._serialized_end = 2213
-    _DATATYPE_DATE._serialized_start = 2215
-    _DATATYPE_DATE._serialized_end = 2279
-    _DATATYPE_TIMESTAMPNTZ._serialized_start = 2281
-    _DATATYPE_TIMESTAMPNTZ._serialized_end = 2353
-    _DATATYPE_CALENDARINTERVAL._serialized_start = 2355
-    _DATATYPE_CALENDARINTERVAL._serialized_end = 2431
-    _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2434
-    _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2613
-    _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2616
-    _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2793
-    _DATATYPE_CHAR._serialized_start = 2795
-    _DATATYPE_CHAR._serialized_end = 2883
-    _DATATYPE_VARCHAR._serialized_start = 2885
-    _DATATYPE_VARCHAR._serialized_end = 2976
-    _DATATYPE_DECIMAL._serialized_start = 2979
-    _DATATYPE_DECIMAL._serialized_end = 3132
-    _DATATYPE_STRUCTFIELD._serialized_start = 3135
-    _DATATYPE_STRUCTFIELD._serialized_end = 3296
-    _DATATYPE_STRUCT._serialized_start = 3298
-    _DATATYPE_STRUCT._serialized_end = 3425
-    _DATATYPE_ARRAY._serialized_start = 3428
-    _DATATYPE_ARRAY._serialized_end = 3590
-    _DATATYPE_MAP._serialized_start = 3593
-    _DATATYPE_MAP._serialized_end = 3812
-    _DATATYPE_UDT._serialized_start = 3815
-    _DATATYPE_UDT._serialized_end = 4086
+    _DATATYPE._serialized_end = 4212
+    _DATATYPE_BOOLEAN._serialized_start = 1534
+    _DATATYPE_BOOLEAN._serialized_end = 1601
+    _DATATYPE_BYTE._serialized_start = 1603
+    _DATATYPE_BYTE._serialized_end = 1667
+    _DATATYPE_SHORT._serialized_start = 1669
+    _DATATYPE_SHORT._serialized_end = 1734
+    _DATATYPE_INTEGER._serialized_start = 1736
+    _DATATYPE_INTEGER._serialized_end = 1803
+    _DATATYPE_LONG._serialized_start = 1805
+    _DATATYPE_LONG._serialized_end = 1869
+    _DATATYPE_FLOAT._serialized_start = 1871
+    _DATATYPE_FLOAT._serialized_end = 1936
+    _DATATYPE_DOUBLE._serialized_start = 1938
+    _DATATYPE_DOUBLE._serialized_end = 2004
+    _DATATYPE_STRING._serialized_start = 2006
+    _DATATYPE_STRING._serialized_end = 2072
+    _DATATYPE_BINARY._serialized_start = 2074
+    _DATATYPE_BINARY._serialized_end = 2140
+    _DATATYPE_NULL._serialized_start = 2142
+    _DATATYPE_NULL._serialized_end = 2206
+    _DATATYPE_TIMESTAMP._serialized_start = 2208
+    _DATATYPE_TIMESTAMP._serialized_end = 2277
+    _DATATYPE_DATE._serialized_start = 2279
+    _DATATYPE_DATE._serialized_end = 2343
+    _DATATYPE_TIMESTAMPNTZ._serialized_start = 2345
+    _DATATYPE_TIMESTAMPNTZ._serialized_end = 2417
+    _DATATYPE_CALENDARINTERVAL._serialized_start = 2419
+    _DATATYPE_CALENDARINTERVAL._serialized_end = 2495
+    _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2498
+    _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2677
+    _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2680
+    _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2857
+    _DATATYPE_CHAR._serialized_start = 2859
+    _DATATYPE_CHAR._serialized_end = 2947
+    _DATATYPE_VARCHAR._serialized_start = 2949
+    _DATATYPE_VARCHAR._serialized_end = 3040
+    _DATATYPE_DECIMAL._serialized_start = 3043
+    _DATATYPE_DECIMAL._serialized_end = 3196
+    _DATATYPE_STRUCTFIELD._serialized_start = 3199
+    _DATATYPE_STRUCTFIELD._serialized_end = 3360
+    _DATATYPE_STRUCT._serialized_start = 3362
+    _DATATYPE_STRUCT._serialized_end = 3489
+    _DATATYPE_ARRAY._serialized_start = 3492
+    _DATATYPE_ARRAY._serialized_end = 3654
+    _DATATYPE_MAP._serialized_start = 3657
+    _DATATYPE_MAP._serialized_end = 3876
+    _DATATYPE_UDT._serialized_start = 3879
+    _DATATYPE_UDT._serialized_end = 4150
+    _DATATYPE_UNPARSED._serialized_start = 4152
+    _DATATYPE_UNPARSED._serialized_end = 4204
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index d09d78ac7b0..956701b4c36 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -716,6 +716,21 @@ class DataType(google.protobuf.message.Message):
             ],
         ) -> typing_extensions.Literal["serialized_python_class"] | None: ...
 
+    class Unparsed(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        DATA_TYPE_STRING_FIELD_NUMBER: builtins.int
+        data_type_string: builtins.str
+        """(Required) The unparsed data type string"""
+        def __init__(
+            self,
+            *,
+            data_type_string: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["data_type_string", b"data_type_string"]
+        ) -> None: ...
+
     NULL_FIELD_NUMBER: builtins.int
     BINARY_FIELD_NUMBER: builtins.int
     BOOLEAN_FIELD_NUMBER: builtins.int
@@ -739,6 +754,7 @@ class DataType(google.protobuf.message.Message):
     STRUCT_FIELD_NUMBER: builtins.int
     MAP_FIELD_NUMBER: builtins.int
     UDT_FIELD_NUMBER: builtins.int
+    UNPARSED_FIELD_NUMBER: builtins.int
     @property
     def null(self) -> global___DataType.NULL: ...
     @property
@@ -791,6 +807,9 @@ class DataType(google.protobuf.message.Message):
     @property
     def udt(self) -> global___DataType.UDT:
         """UserDefinedType"""
+    @property
+    def unparsed(self) -> global___DataType.Unparsed:
+        """UnparsedDataType"""
     def __init__(
         self,
         *,
@@ -817,6 +836,7 @@ class DataType(google.protobuf.message.Message):
         struct: global___DataType.Struct | None = ...,
         map: global___DataType.Map | None = ...,
         udt: global___DataType.UDT | None = ...,
+        unparsed: global___DataType.Unparsed | None = ...,
     ) -> None: ...
     def HasField(
         self,
@@ -865,6 +885,8 @@ class DataType(google.protobuf.message.Message):
             b"timestamp_ntz",
             "udt",
             b"udt",
+            "unparsed",
+            b"unparsed",
             "var_char",
             b"var_char",
             "year_month_interval",
@@ -918,6 +940,8 @@ class DataType(google.protobuf.message.Message):
             b"timestamp_ntz",
             "udt",
             b"udt",
+            "unparsed",
+            b"unparsed",
             "var_char",
             b"var_char",
             "year_month_interval",
@@ -950,6 +974,7 @@ class DataType(google.protobuf.message.Message):
         "struct",
         "map",
         "udt",
+        "unparsed",
     ] | None: ...
 
 global___DataType = DataType
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 475bd2fb6bd..9d9af112da4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -321,7 +321,13 @@ class SparkSession:
                 # we can not infer the schema from the data itself.
                 warnings.warn("failed to infer the schema from data")
                 if _schema is None and _schema_str is not None:
-                    _schema = self.createDataFrame([], schema=_schema_str).schema
+                    _parsed = self.client._analyze(
+                        method="ddl_parse", ddl_string=_schema_str
+                    ).parsed
+                    if isinstance(_parsed, StructType):
+                        _schema = _parsed
+                    elif isinstance(_parsed, DataType):
+                        _schema = StructType().add("value", _parsed)
                 if _schema is None or not isinstance(_schema, StructType):
                     raise ValueError(
                         "Some of types cannot be determined after inferring, "
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8e91709b8fc..d3c0fbc0272 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -22,7 +22,7 @@ import json
 
 import pyarrow as pa
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from pyspark.sql.types import (
     DataType,
@@ -51,7 +51,6 @@ from pyspark.sql.types import (
 )
 
 import pyspark.sql.connect.proto as pb2
-from pyspark.sql.utils import is_remote
 
 
 JVM_BYTE_MIN: int = -(1 << 7)
@@ -64,6 +63,63 @@ JVM_LONG_MIN: int = -(1 << 63)
 JVM_LONG_MAX: int = (1 << 63) - 1
 
 
+class UnparsedDataType(DataType):
+    """
+    Unparsed data type.
+
+    The data type string will be parsed later.
+
+    Parameters
+    ----------
+    data_type_string : str
+        The data type string format equals :class:`DataType.simpleString`,
+        except that the top level struct type can omit the ``struct<>``.
+        This also supports a schema in a DDL-formatted string and case-insensitive strings.
+
+    Examples
+    --------
+    >>> from pyspark.sql.connect.types import UnparsedDataType
+
+    >>> UnparsedDataType("int ")
+    UnparsedDataType('int ')
+    >>> UnparsedDataType("INT ")
+    UnparsedDataType('INT ')
+    >>> UnparsedDataType("a: byte, b: decimal(  16 , 8   ) ")
+    UnparsedDataType('a: byte, b: decimal(  16 , 8   ) ')
+    >>> UnparsedDataType("a DOUBLE, b STRING")
+    UnparsedDataType('a DOUBLE, b STRING')
+    >>> UnparsedDataType("a DOUBLE, b CHAR( 50 )")
+    UnparsedDataType('a DOUBLE, b CHAR( 50 )')
+    >>> UnparsedDataType("a DOUBLE, b VARCHAR( 50 )")
+    UnparsedDataType('a DOUBLE, b VARCHAR( 50 )')
+    >>> UnparsedDataType("a: array< short>")
+    UnparsedDataType('a: array< short>')
+    >>> UnparsedDataType(" map<string , string > ")
+    UnparsedDataType(' map<string , string > ')
+    """
+
+    def __init__(self, data_type_string: str):
+        self.data_type_string = data_type_string
+
+    def simpleString(self) -> str:
+        return "unparsed(%s)" % repr(self.data_type_string)
+
+    def __repr__(self) -> str:
+        return "UnparsedDataType(%s)" % repr(self.data_type_string)
+
+    def jsonValue(self) -> Dict[str, Any]:
+        raise AssertionError("Invalid call to jsonValue on unresolved object")
+
+    def needConversion(self) -> bool:
+        raise AssertionError("Invalid call to needConversion on unresolved object")
+
+    def toInternal(self, obj: Any) -> Any:
+        raise AssertionError("Invalid call to toInternal on unresolved object")
+
+    def fromInternal(self, obj: Any) -> Any:
+        raise AssertionError("Invalid call to fromInternal on unresolved object")
+
+
 def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
     ret = pb2.DataType()
     if isinstance(data_type, NullType):
@@ -125,6 +181,9 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
             ret.udt.serialized_python_class = json_value["serializedClass"]
         ret.udt.python_class = json_value["pyClass"]
         ret.udt.sql_type.CopyFrom(pyspark_types_to_proto_types(data_type.sqlType()))
+    elif isinstance(data_type, UnparsedDataType):
+        data_type_string = data_type.data_type_string
+        ret.unparsed.data_type_string = data_type_string
     else:
         raise Exception(f"Unsupported data type {data_type}")
     return ret
@@ -339,23 +398,3 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
             for field in arrow_schema
         ]
     )
-
-
-def parse_data_type(data_type: str) -> DataType:
-    # Currently we don't have a way to have a current Spark session in Spark Connect, and
-    # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-    # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-    # Spark session for Spark Connect in the future.
-    from pyspark.sql import SparkSession as PySparkSession
-
-    assert is_remote()
-    return_type_schema = (
-        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
-    )
-    with_col_name = " " in data_type.strip()
-    if len(return_type_schema.fields) == 1 and not with_col_name:
-        # To match pyspark.sql.types._parse_datatype_string
-        return_type = return_type_schema.fields[0].dataType
-    else:
-        return_type = return_type_schema
-    return return_type
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index bb7b70e613a..2128c908081 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -26,14 +26,13 @@ import functools
 from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
-from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.connect.expressions import (
     ColumnReference,
     PythonUDF,
     CommonInlineUserDefinedFunction,
 )
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.types import parse_data_type
+from pyspark.sql.connect.types import UnparsedDataType
 from pyspark.sql.types import DataType, StringType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 
@@ -99,8 +98,8 @@ class UserDefinedFunction:
             )
 
         self.func = func
-        self._returnType = (
-            parse_data_type(returnType) if isinstance(returnType, str) else returnType
+        self._returnType: DataType = (
+            UnparsedDataType(returnType) if isinstance(returnType, str) else returnType
         )
         self._name = name or (
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
@@ -116,13 +115,10 @@ class UserDefinedFunction:
         ]
         arg_exprs = [col._expr for col in arg_cols]
 
-        data_type_str = (
-            self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType
-        )
         py_udf = PythonUDF(
-            output_type=data_type_str,
+            output_type=self._returnType,
             eval_type=self.evalType,
-            command=CloudPickleSerializer().dumps((self.func, self._returnType)),
+            func=self.func,
             python_ver="%d.%d" % sys.version_info[:2],
         )
         return CommonInlineUserDefinedFunction(
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index b38b4c28a25..50f0d36be5d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -25,10 +25,8 @@ if should_test_connect:
 
     sql.udf.UserDefinedFunction = UserDefinedFunction
 
-from pyspark.errors import AnalysisException
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.sql.types import IntegerType
 
 
 class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
@@ -91,8 +89,8 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
         super().test_worker_original_stdin_closed()
 
     @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
-    def test_udf(self):
-        super().test_udf()
+    def test_udf_on_sql_context(self):
+        super().test_udf_on_sql_context()
 
     # TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
     @unittest.skip("Fails in Spark Connect, should enable.")
@@ -104,22 +102,13 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_registration_return_type_none(self):
         super().test_udf_registration_return_type_none()
 
-    def test_non_existed_udf(self):
-        spark = self.spark
-        self.assertRaisesRegex(
-            AnalysisException,
-            "Can not load class non_existed_udf",
-            lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
-        )
-
-    def test_udf_registration_returns_udf(self):
-        df = self.spark.range(10)
-        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
-
-        self.assertListEqual(
-            df.selectExpr("add_three(id) AS plus_three").collect(),
-            df.select(add_three("id").alias("plus_three")).collect(),
-        )
+    @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
+    def test_non_existed_udf_with_sql_context(self):
+        super().test_non_existed_udf_with_sql_context()
+
+    @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
+    def test_udf_registration_returns_udf_on_sql_context(self):
+        super().test_udf_registration_returns_udf_on_sql_context()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 0f93babbd6c..b766b0c0178 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -74,6 +74,7 @@ class BaseUDFTestsMixin(object):
         [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
+    def test_udf_on_sql_context(self):
         # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
         sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
@@ -369,6 +370,9 @@ class BaseUDFTestsMixin(object):
             df.select(add_three("id").alias("plus_three")).collect(),
         )
 
+    def test_udf_registration_returns_udf_on_sql_context(self):
+        df = self.spark.range(10)
+
         # This is to check if a 'SQLContext.udf' can call its alias.
         sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
@@ -416,6 +420,7 @@ class BaseUDFTestsMixin(object):
             lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
         )
 
+    def test_non_existed_udf_with_sql_context(self):
         # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
         sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         self.assertRaisesRegex(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org