Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/09 11:14:28 UTC
[spark] branch master updated: [SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c0b1735c0bf [SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available
c0b1735c0bf is described below
commit c0b1735c0bfeb1ff645d146e262d7ccd036a590e
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Thu Mar 9 20:14:12 2023 +0900
[SPARK-42630][CONNECT][PYTHON] Introduce UnparsedDataType and delay parsing DDL string until SparkConnectClient is available
### What changes were proposed in this pull request?
Introduces `UnparsedDataType` and delays parsing the DDL string for Python UDFs until a `SparkConnectClient` is available.
`UnparsedDataType` carries the DDL string as-is; the string is parsed on the server side.
It must not be nested inside other data types.
Also changes `createDataFrame` to use the proto `DDLParse`.
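For illustration, a minimal sketch of the new client-side placeholder (class and method names come from the `types.py` changes in this patch):

    from pyspark.sql.connect.types import UnparsedDataType

    # The DDL string is stored verbatim; nothing is parsed on the client yet.
    rt = UnparsedDataType("a: int, b: string")
    print(rt)                  # UnparsedDataType('a: int, b: string')
    print(rt.simpleString())   # unparsed('a: int, b: string')

    # Calls that need a resolved type fail by design until the server parses the string,
    # e.g. rt.jsonValue() raises AssertionError("Invalid call to jsonValue on unresolved object").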
### Why are the changes needed?
Currently, `parse_data_type` depends on `PySparkSession`, which creates a local PySpark session, but that won't be available on the client side.
Once a `SparkConnectClient` is available, we can use the new `DDLParse` proto request to parse data types given as strings.
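A hedged sketch of the round trip once a client exists (mirroring the `session.py` and `client.py` changes below; `spark` is assumed to be a Spark Connect session):

    # The raw DDL string goes to the server; AnalyzeResult carries the parsed type back,
    # already converted to a pyspark.sql.types.DataType via proto_schema_to_pyspark_data_type.
    parsed = spark.client._analyze(method="ddl_parse", ddl_string="a: int, b: string").parsed
    assert parsed is not None  # a StructType for this DDL string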
### Does this PR introduce _any_ user-facing change?
In Spark Connect, a UDF's `returnType` attribute can be a string if it was provided as a string.
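For example (a sketch assuming a Spark Connect session; the parameter name comes from `udf.py` in this patch):

    from pyspark.sql.connect.udf import UserDefinedFunction

    # The string is kept on the client as UnparsedDataType("integer") and only
    # parsed on the server when the UDF is serialized into the plan.
    plus_one = UserDefinedFunction(lambda x: x + 1, returnType="integer")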
### How was this patch tested?
Existing tests.
Closes #40260 from ueshin/issues/SPARK-42630/ddl_parse.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 4 +-
.../src/main/protobuf/spark/connect/types.proto | 8 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 49 ++++-----
.../messages/ConnectProtoMessagesSuite.scala | 4 +-
python/pyspark/sql/connect/client.py | 63 +++---------
python/pyspark/sql/connect/expressions.py | 38 ++++---
.../pyspark/sql/connect/proto/expressions_pb2.py | 12 +--
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 17 ++--
python/pyspark/sql/connect/proto/types_pb2.py | 113 ++++++++++++---------
python/pyspark/sql/connect/proto/types_pb2.pyi | 25 +++++
python/pyspark/sql/connect/session.py | 8 +-
python/pyspark/sql/connect/types.py | 83 +++++++++++----
python/pyspark/sql/connect/udf.py | 14 +--
.../pyspark/sql/tests/connect/test_parity_udf.py | 29 ++----
python/pyspark/sql/tests/test_udf.py | 5 +
15 files changed, 263 insertions(+), 209 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 0aee3ca13b9..9e949dab15a 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -326,7 +326,7 @@ message CommonInlineUserDefinedFunction {
message PythonUDF {
// (Required) Output type of the Python UDF
- string output_type = 1;
+ DataType output_type = 1;
// (Required) EvalType of the Python UDF
int32 eval_type = 2;
// (Required) The encoded commands of the Python UDF
@@ -351,7 +351,7 @@ message JavaUDF {
string class_name = 1;
// (Optional) Output type of the Java UDF
- optional string output_type = 2;
+ optional DataType output_type = 2;
// (Required) Indicate if the Java user-defined function is an aggregate function
bool aggregate = 3;
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/types.proto b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
index 03a6968af60..68833b5d220 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -64,6 +64,9 @@ message DataType {
// UserDefinedType
UDT udt = 23;
+
+ // UnparsedDataType
+ Unparsed unparsed = 24;
}
message Boolean {
@@ -183,4 +186,9 @@ message DataType {
optional string serialized_python_class = 4;
DataType sql_type = 5;
}
+
+ message Unparsed {
+ // (Required) The unparsed data type string
+ string data_type_string = 1;
+ }
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cd4da39d62f..010d0236c74 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -455,7 +455,7 @@ class SparkConnectPlanner(val session: SparkSession) {
}
private def transformToSchema(rel: proto.ToSchema): LogicalPlan = {
- val schema = DataTypeProtoConverter.toCatalystType(rel.getSchema)
+ val schema = transformDataType(rel.getSchema)
assert(schema.isInstanceOf[StructType])
Dataset
@@ -616,6 +616,14 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ private def transformDataType(t: proto.DataType): DataType = {
+ t.getKindCase match {
+ case proto.DataType.KindCase.UNPARSED =>
+ parseDatatypeString(t.getUnparsed.getDataTypeString)
+ case _ => DataTypeProtoConverter.toCatalystType(t)
+ }
+ }
+
private[connect] def parseDatatypeString(sqlText: String): DataType = {
val parser = session.sessionState.sqlParser
try {
@@ -960,13 +968,7 @@ class SparkConnectPlanner(val session: SparkSession) {
PythonUDF(
name = fun.getFunctionName,
func = transformPythonFunction(udf),
- dataType = DataType.parseTypeWithFallback(
- schema = udf.getOutputType,
- parser = DataType.fromDDL,
- fallbackParser = DataType.fromJson) match {
- case s: DataType => s
- case other => throw InvalidPlanInput(s"Invalid return type $other")
- },
+ dataType = transformDataType(udf.getOutputType),
children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
evalType = udf.getEvalType,
udfDeterministic = fun.getDeterministic)
@@ -1220,9 +1222,7 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformCast(cast: proto.Expression.Cast): Expression = {
cast.getCastToTypeCase match {
case proto.Expression.Cast.CastToTypeCase.TYPE =>
- Cast(
- transformExpression(cast.getExpr),
- DataTypeProtoConverter.toCatalystType(cast.getType))
+ Cast(transformExpression(cast.getExpr), transformDataType(cast.getType))
case _ =>
Cast(
transformExpression(cast.getExpr),
@@ -1615,13 +1615,7 @@ class SparkConnectPlanner(val session: SparkSession) {
val udpf = UserDefinedPythonFunction(
name = fun.getFunctionName,
func = function,
- dataType = DataType.parseTypeWithFallback(
- schema = udf.getOutputType,
- parser = DataType.fromDDL,
- fallbackParser = DataType.fromJson) match {
- case s: DataType => s
- case other => throw InvalidPlanInput(s"Invalid return type $other")
- },
+ dataType = transformDataType(udf.getOutputType),
pythonEvalType = udf.getEvalType,
udfDeterministic = fun.getDeterministic)
@@ -1630,16 +1624,11 @@ class SparkConnectPlanner(val session: SparkSession) {
private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
val udf = fun.getJavaUdf
- val dataType =
- if (udf.hasOutputType) {
- DataType.parseTypeWithFallback(
- schema = udf.getOutputType,
- parser = DataType.fromDDL,
- fallbackParser = DataType.fromJson) match {
- case s: DataType => s
- case other => throw InvalidPlanInput(s"Invalid return type $other")
- }
- } else null
+ val dataType = if (udf.hasOutputType) {
+ transformDataType(udf.getOutputType)
+ } else {
+ null
+ }
if (udf.getAggregate) {
session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName)
} else {
@@ -1945,7 +1934,7 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformCreateExternalTable(
getCreateExternalTable: proto.CreateExternalTable): LogicalPlan = {
val schema = if (getCreateExternalTable.hasSchema) {
- val struct = DataTypeProtoConverter.toCatalystType(getCreateExternalTable.getSchema)
+ val struct = transformDataType(getCreateExternalTable.getSchema)
assert(struct.isInstanceOf[StructType])
struct.asInstanceOf[StructType]
} else {
@@ -1975,7 +1964,7 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformCreateTable(getCreateTable: proto.CreateTable): LogicalPlan = {
val schema = if (getCreateTable.hasSchema) {
- val struct = DataTypeProtoConverter.toCatalystType(getCreateTable.getSchema)
+ val struct = transformDataType(getCreateTable.getSchema)
assert(struct.isInstanceOf[StructType])
struct.asInstanceOf[StructType]
} else {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 09462ce18c2..65c03a3c2e2 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -20,6 +20,8 @@ import com.google.protobuf.ByteString
import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.types.IntegerType
class ConnectProtoMessagesSuite extends SparkFunSuite {
test("UserContext can deal with extensions") {
@@ -61,7 +63,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
val pythonUdf = proto.PythonUDF
.newBuilder()
.setEvalType(100)
- .setOutputType("\"integer\"")
+ .setOutputType(DataTypeProtoConverter.toConnectProtoType(IntegerType))
.setCommand(ByteString.copyFrom("command".getBytes()))
.setPythonVer("3.10")
.build()
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index baa6d641422..3c91661716e 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -31,7 +31,6 @@ import random
import time
import urllib.parse
import uuid
-import json
import sys
from types import TracebackType
from typing import (
@@ -72,13 +71,7 @@ from pyspark.sql.connect.expressions import (
CommonInlineUserDefinedFunction,
JavaUDF,
)
-from pyspark.sql.connect.types import parse_data_type
-from pyspark.sql.types import (
- DataType,
- StructType,
- StructField,
-)
-from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.types import DataType, StructType
from pyspark.rdd import PythonEvalType
@@ -399,14 +392,14 @@ class PlanObservedMetrics:
class AnalyzeResult:
def __init__(
self,
- schema: Optional[pb2.DataType],
+ schema: Optional[DataType],
explain_string: Optional[str],
tree_string: Optional[str],
is_local: Optional[bool],
is_streaming: Optional[bool],
input_files: Optional[List[str]],
spark_version: Optional[str],
- parsed: Optional[pb2.DataType],
+ parsed: Optional[DataType],
is_same_semantics: Optional[bool],
):
self.schema = schema
@@ -421,18 +414,18 @@ class AnalyzeResult:
@classmethod
def fromProto(cls, pb: Any) -> "AnalyzeResult":
- schema: Optional[pb2.DataType] = None
+ schema: Optional[DataType] = None
explain_string: Optional[str] = None
tree_string: Optional[str] = None
is_local: Optional[bool] = None
is_streaming: Optional[bool] = None
input_files: Optional[List[str]] = None
spark_version: Optional[str] = None
- parsed: Optional[pb2.DataType] = None
+ parsed: Optional[DataType] = None
is_same_semantics: Optional[bool] = None
if pb.HasField("schema"):
- schema = pb.schema.schema
+ schema = types.proto_schema_to_pyspark_data_type(pb.schema.schema)
elif pb.HasField("explain"):
explain_string = pb.explain.explain_string
elif pb.HasField("tree_string"):
@@ -446,7 +439,7 @@ class AnalyzeResult:
elif pb.HasField("spark_version"):
spark_version = pb.spark_version.version
elif pb.HasField("ddl_parse"):
- parsed = pb.ddl_parse.parsed
+ parsed = types.proto_schema_to_pyspark_data_type(pb.ddl_parse.parsed)
elif pb.HasField("same_semantics"):
is_same_semantics = pb.same_semantics.result
else:
@@ -553,14 +546,11 @@ class SparkConnectClient(object):
if name is None:
name = f"fun_{uuid.uuid4().hex}"
- # convert str return_type to DataType
- if isinstance(return_type, str):
- return_type = parse_data_type(return_type)
# construct a PythonUDF
py_udf = PythonUDF(
- output_type=return_type.json(),
+ output_type=return_type,
eval_type=eval_type,
- command=CloudPickleSerializer().dumps((function, return_type)),
+ func=function,
python_ver="%d.%d" % sys.version_info[:2],
)
@@ -586,18 +576,11 @@ class SparkConnectClient(object):
return_type: Optional["DataTypeOrString"] = None,
aggregate: bool = False,
) -> None:
- # convert str return_type to DataType
- if isinstance(return_type, str):
- return_type = parse_data_type(return_type)
-
# construct a JavaUDF
if return_type is None:
java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate)
else:
- java_udf = JavaUDF(
- class_name=javaClassName,
- output_type=return_type.json(),
- )
+ java_udf = JavaUDF(class_name=javaClassName, output_type=return_type)
fun = CommonInlineUserDefinedFunction(
function_name=name,
function=java_udf,
@@ -660,9 +643,6 @@ class SparkConnectClient(object):
pdf.attrs["observed_metrics"] = observed_metrics
return pdf
- def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
- return types.proto_schema_to_pyspark_data_type(schema)
-
def _proto_to_string(self, p: google.protobuf.message.Message) -> str:
"""
Helper method to generate a one line string representation of the plan.
@@ -682,26 +662,11 @@ class SparkConnectClient(object):
Return schema for given plan.
"""
logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
- proto_schema = self._analyze(method="schema", plan=plan).schema
- assert proto_schema is not None
+ schema = self._analyze(method="schema", plan=plan).schema
+ assert schema is not None
# Server side should populate the struct field which is the schema.
- assert proto_schema.HasField("struct")
-
- fields = []
- for f in proto_schema.struct.fields:
- if f.HasField("metadata"):
- metadata = json.loads(f.metadata)
- else:
- metadata = None
- fields.append(
- StructField(
- f.name,
- self._proto_schema_to_pyspark_schema(f.data_type),
- f.nullable,
- metadata,
- )
- )
- return StructType(fields)
+ assert isinstance(schema, StructType)
+ return schema
def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
"""
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 64176327c16..5c122f40373 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -22,6 +22,7 @@ from typing import (
cast,
TYPE_CHECKING,
Any,
+ Callable,
Union,
Sequence,
Tuple,
@@ -36,6 +37,7 @@ from threading import Lock
import numpy as np
+from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import (
_from_numpy_type,
DateType,
@@ -66,6 +68,7 @@ from pyspark.sql.connect.types import (
JVM_INT_MAX,
JVM_LONG_MIN,
JVM_LONG_MAX,
+ UnparsedDataType,
pyspark_types_to_proto_types,
)
@@ -496,29 +499,36 @@ class PythonUDF:
def __init__(
self,
- output_type: str,
+ output_type: Union[DataType, str],
eval_type: int,
- command: bytes,
+ func: Callable[..., Any],
python_ver: str,
) -> None:
- self._output_type = output_type
+ self._output_type: DataType = (
+ UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
+ )
self._eval_type = eval_type
- self._command = command
+ self._func = func
self._python_ver = python_ver
def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
+ if isinstance(self._output_type, UnparsedDataType):
+ parsed = session._analyze(
+ method="ddl_parse", ddl_string=self._output_type.data_type_string
+ ).parsed
+ assert isinstance(parsed, DataType)
+ output_type = parsed
+ else:
+ output_type = self._output_type
expr = proto.PythonUDF()
- expr.output_type = self._output_type
+ expr.output_type.CopyFrom(pyspark_types_to_proto_types(output_type))
expr.eval_type = self._eval_type
- expr.command = self._command
+ expr.command = CloudPickleSerializer().dumps((self._func, output_type))
expr.python_ver = self._python_ver
return expr
def __repr__(self) -> str:
- return (
- f"{self._output_type}, {self._eval_type}, "
- f"{self._command}, f{self._python_ver}" # type: ignore[str-bytes-safe]
- )
+ return f"{self._output_type}, {self._eval_type}, {self._func}, f{self._python_ver}"
class JavaUDF:
@@ -527,18 +537,20 @@ class JavaUDF:
def __init__(
self,
class_name: str,
- output_type: Optional[str] = None,
+ output_type: Optional[Union[DataType, str]] = None,
aggregate: bool = False,
) -> None:
self._class_name = class_name
- self._output_type = output_type
+ self._output_type: Optional[DataType] = (
+ UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
+ )
self._aggregate = aggregate
def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
expr = proto.JavaUDF()
expr.class_name = self._class_name
if self._output_type is not None:
- expr.output_type = self._output_type
+ expr.output_type.CopyFrom(pyspark_types_to_proto_types(self._output_type))
expr.aggregate = self._aggregate
return expr
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 24dd1136480..1814f52f539 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
)
@@ -371,9 +371,9 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
_PYTHONUDF._serialized_start = 5507
- _PYTHONUDF._serialized_end = 5637
- _SCALARSCALAUDF._serialized_start = 5640
- _SCALARSCALAUDF._serialized_end = 5824
- _JAVAUDF._serialized_start = 5826
- _JAVAUDF._serialized_end = 5950
+ _PYTHONUDF._serialized_end = 5662
+ _SCALARSCALAUDF._serialized_start = 5665
+ _SCALARSCALAUDF._serialized_end = 5849
+ _JAVAUDF._serialized_start = 5852
+ _JAVAUDF._serialized_end = 6001
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 19b47c7ab91..3c8de8abb4e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1242,8 +1242,9 @@ class PythonUDF(google.protobuf.message.Message):
EVAL_TYPE_FIELD_NUMBER: builtins.int
COMMAND_FIELD_NUMBER: builtins.int
PYTHON_VER_FIELD_NUMBER: builtins.int
- output_type: builtins.str
- """(Required) Output type of the Python UDF"""
+ @property
+ def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+ """(Required) Output type of the Python UDF"""
eval_type: builtins.int
"""(Required) EvalType of the Python UDF"""
command: builtins.bytes
@@ -1253,11 +1254,14 @@ class PythonUDF(google.protobuf.message.Message):
def __init__(
self,
*,
- output_type: builtins.str = ...,
+ output_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
eval_type: builtins.int = ...,
command: builtins.bytes = ...,
python_ver: builtins.str = ...,
) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["output_type", b"output_type"]
+ ) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
@@ -1331,15 +1335,16 @@ class JavaUDF(google.protobuf.message.Message):
AGGREGATE_FIELD_NUMBER: builtins.int
class_name: builtins.str
"""(Required) Fully qualified name of Java class"""
- output_type: builtins.str
- """(Optional) Output type of the Java UDF"""
+ @property
+ def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+ """(Optional) Output type of the Java UDF"""
aggregate: builtins.bool
"""(Required) Indicate if the Java user-defined function is an aggregate function"""
def __init__(
self,
*,
class_name: builtins.str = ...,
- output_type: builtins.str | None = ...,
+ output_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
aggregate: builtins.bool = ...,
) -> None: ...
def HasField(
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index cba709c19c9..eec58d5cee6 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -30,7 +30,7 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xd1\x1f\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01( [...]
+ b'\n\x19spark/connect/types.proto\x12\rspark.connect"\xc7 \n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0 [...]
)
@@ -59,6 +59,7 @@ _DATATYPE_STRUCT = _DATATYPE.nested_types_by_name["Struct"]
_DATATYPE_ARRAY = _DATATYPE.nested_types_by_name["Array"]
_DATATYPE_MAP = _DATATYPE.nested_types_by_name["Map"]
_DATATYPE_UDT = _DATATYPE.nested_types_by_name["UDT"]
+_DATATYPE_UNPARSED = _DATATYPE.nested_types_by_name["Unparsed"]
DataType = _reflection.GeneratedProtocolMessageType(
"DataType",
(_message.Message,),
@@ -279,6 +280,15 @@ DataType = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.DataType.UDT)
},
),
+ "Unparsed": _reflection.GeneratedProtocolMessageType(
+ "Unparsed",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _DATATYPE_UNPARSED,
+ "__module__": "spark.connect.types_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.DataType.Unparsed)
+ },
+ ),
"DESCRIPTOR": _DATATYPE,
"__module__": "spark.connect.types_pb2"
# @@protoc_insertion_point(class_scope:spark.connect.DataType)
@@ -309,59 +319,62 @@ _sym_db.RegisterMessage(DataType.Struct)
_sym_db.RegisterMessage(DataType.Array)
_sym_db.RegisterMessage(DataType.Map)
_sym_db.RegisterMessage(DataType.UDT)
+_sym_db.RegisterMessage(DataType.Unparsed)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_DATATYPE._serialized_start = 45
- _DATATYPE._serialized_end = 4094
- _DATATYPE_BOOLEAN._serialized_start = 1470
- _DATATYPE_BOOLEAN._serialized_end = 1537
- _DATATYPE_BYTE._serialized_start = 1539
- _DATATYPE_BYTE._serialized_end = 1603
- _DATATYPE_SHORT._serialized_start = 1605
- _DATATYPE_SHORT._serialized_end = 1670
- _DATATYPE_INTEGER._serialized_start = 1672
- _DATATYPE_INTEGER._serialized_end = 1739
- _DATATYPE_LONG._serialized_start = 1741
- _DATATYPE_LONG._serialized_end = 1805
- _DATATYPE_FLOAT._serialized_start = 1807
- _DATATYPE_FLOAT._serialized_end = 1872
- _DATATYPE_DOUBLE._serialized_start = 1874
- _DATATYPE_DOUBLE._serialized_end = 1940
- _DATATYPE_STRING._serialized_start = 1942
- _DATATYPE_STRING._serialized_end = 2008
- _DATATYPE_BINARY._serialized_start = 2010
- _DATATYPE_BINARY._serialized_end = 2076
- _DATATYPE_NULL._serialized_start = 2078
- _DATATYPE_NULL._serialized_end = 2142
- _DATATYPE_TIMESTAMP._serialized_start = 2144
- _DATATYPE_TIMESTAMP._serialized_end = 2213
- _DATATYPE_DATE._serialized_start = 2215
- _DATATYPE_DATE._serialized_end = 2279
- _DATATYPE_TIMESTAMPNTZ._serialized_start = 2281
- _DATATYPE_TIMESTAMPNTZ._serialized_end = 2353
- _DATATYPE_CALENDARINTERVAL._serialized_start = 2355
- _DATATYPE_CALENDARINTERVAL._serialized_end = 2431
- _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2434
- _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2613
- _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2616
- _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2793
- _DATATYPE_CHAR._serialized_start = 2795
- _DATATYPE_CHAR._serialized_end = 2883
- _DATATYPE_VARCHAR._serialized_start = 2885
- _DATATYPE_VARCHAR._serialized_end = 2976
- _DATATYPE_DECIMAL._serialized_start = 2979
- _DATATYPE_DECIMAL._serialized_end = 3132
- _DATATYPE_STRUCTFIELD._serialized_start = 3135
- _DATATYPE_STRUCTFIELD._serialized_end = 3296
- _DATATYPE_STRUCT._serialized_start = 3298
- _DATATYPE_STRUCT._serialized_end = 3425
- _DATATYPE_ARRAY._serialized_start = 3428
- _DATATYPE_ARRAY._serialized_end = 3590
- _DATATYPE_MAP._serialized_start = 3593
- _DATATYPE_MAP._serialized_end = 3812
- _DATATYPE_UDT._serialized_start = 3815
- _DATATYPE_UDT._serialized_end = 4086
+ _DATATYPE._serialized_end = 4212
+ _DATATYPE_BOOLEAN._serialized_start = 1534
+ _DATATYPE_BOOLEAN._serialized_end = 1601
+ _DATATYPE_BYTE._serialized_start = 1603
+ _DATATYPE_BYTE._serialized_end = 1667
+ _DATATYPE_SHORT._serialized_start = 1669
+ _DATATYPE_SHORT._serialized_end = 1734
+ _DATATYPE_INTEGER._serialized_start = 1736
+ _DATATYPE_INTEGER._serialized_end = 1803
+ _DATATYPE_LONG._serialized_start = 1805
+ _DATATYPE_LONG._serialized_end = 1869
+ _DATATYPE_FLOAT._serialized_start = 1871
+ _DATATYPE_FLOAT._serialized_end = 1936
+ _DATATYPE_DOUBLE._serialized_start = 1938
+ _DATATYPE_DOUBLE._serialized_end = 2004
+ _DATATYPE_STRING._serialized_start = 2006
+ _DATATYPE_STRING._serialized_end = 2072
+ _DATATYPE_BINARY._serialized_start = 2074
+ _DATATYPE_BINARY._serialized_end = 2140
+ _DATATYPE_NULL._serialized_start = 2142
+ _DATATYPE_NULL._serialized_end = 2206
+ _DATATYPE_TIMESTAMP._serialized_start = 2208
+ _DATATYPE_TIMESTAMP._serialized_end = 2277
+ _DATATYPE_DATE._serialized_start = 2279
+ _DATATYPE_DATE._serialized_end = 2343
+ _DATATYPE_TIMESTAMPNTZ._serialized_start = 2345
+ _DATATYPE_TIMESTAMPNTZ._serialized_end = 2417
+ _DATATYPE_CALENDARINTERVAL._serialized_start = 2419
+ _DATATYPE_CALENDARINTERVAL._serialized_end = 2495
+ _DATATYPE_YEARMONTHINTERVAL._serialized_start = 2498
+ _DATATYPE_YEARMONTHINTERVAL._serialized_end = 2677
+ _DATATYPE_DAYTIMEINTERVAL._serialized_start = 2680
+ _DATATYPE_DAYTIMEINTERVAL._serialized_end = 2857
+ _DATATYPE_CHAR._serialized_start = 2859
+ _DATATYPE_CHAR._serialized_end = 2947
+ _DATATYPE_VARCHAR._serialized_start = 2949
+ _DATATYPE_VARCHAR._serialized_end = 3040
+ _DATATYPE_DECIMAL._serialized_start = 3043
+ _DATATYPE_DECIMAL._serialized_end = 3196
+ _DATATYPE_STRUCTFIELD._serialized_start = 3199
+ _DATATYPE_STRUCTFIELD._serialized_end = 3360
+ _DATATYPE_STRUCT._serialized_start = 3362
+ _DATATYPE_STRUCT._serialized_end = 3489
+ _DATATYPE_ARRAY._serialized_start = 3492
+ _DATATYPE_ARRAY._serialized_end = 3654
+ _DATATYPE_MAP._serialized_start = 3657
+ _DATATYPE_MAP._serialized_end = 3876
+ _DATATYPE_UDT._serialized_start = 3879
+ _DATATYPE_UDT._serialized_end = 4150
+ _DATATYPE_UNPARSED._serialized_start = 4152
+ _DATATYPE_UNPARSED._serialized_end = 4204
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index d09d78ac7b0..956701b4c36 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -716,6 +716,21 @@ class DataType(google.protobuf.message.Message):
],
) -> typing_extensions.Literal["serialized_python_class"] | None: ...
+ class Unparsed(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ DATA_TYPE_STRING_FIELD_NUMBER: builtins.int
+ data_type_string: builtins.str
+ """(Required) The unparsed data type string"""
+ def __init__(
+ self,
+ *,
+ data_type_string: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["data_type_string", b"data_type_string"]
+ ) -> None: ...
+
NULL_FIELD_NUMBER: builtins.int
BINARY_FIELD_NUMBER: builtins.int
BOOLEAN_FIELD_NUMBER: builtins.int
@@ -739,6 +754,7 @@ class DataType(google.protobuf.message.Message):
STRUCT_FIELD_NUMBER: builtins.int
MAP_FIELD_NUMBER: builtins.int
UDT_FIELD_NUMBER: builtins.int
+ UNPARSED_FIELD_NUMBER: builtins.int
@property
def null(self) -> global___DataType.NULL: ...
@property
@@ -791,6 +807,9 @@ class DataType(google.protobuf.message.Message):
@property
def udt(self) -> global___DataType.UDT:
"""UserDefinedType"""
+ @property
+ def unparsed(self) -> global___DataType.Unparsed:
+ """UnparsedDataType"""
def __init__(
self,
*,
@@ -817,6 +836,7 @@ class DataType(google.protobuf.message.Message):
struct: global___DataType.Struct | None = ...,
map: global___DataType.Map | None = ...,
udt: global___DataType.UDT | None = ...,
+ unparsed: global___DataType.Unparsed | None = ...,
) -> None: ...
def HasField(
self,
@@ -865,6 +885,8 @@ class DataType(google.protobuf.message.Message):
b"timestamp_ntz",
"udt",
b"udt",
+ "unparsed",
+ b"unparsed",
"var_char",
b"var_char",
"year_month_interval",
@@ -918,6 +940,8 @@ class DataType(google.protobuf.message.Message):
b"timestamp_ntz",
"udt",
b"udt",
+ "unparsed",
+ b"unparsed",
"var_char",
b"var_char",
"year_month_interval",
@@ -950,6 +974,7 @@ class DataType(google.protobuf.message.Message):
"struct",
"map",
"udt",
+ "unparsed",
] | None: ...
global___DataType = DataType
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 475bd2fb6bd..9d9af112da4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -321,7 +321,13 @@ class SparkSession:
# we can not infer the schema from the data itself.
warnings.warn("failed to infer the schema from data")
if _schema is None and _schema_str is not None:
- _schema = self.createDataFrame([], schema=_schema_str).schema
+ _parsed = self.client._analyze(
+ method="ddl_parse", ddl_string=_schema_str
+ ).parsed
+ if isinstance(_parsed, StructType):
+ _schema = _parsed
+ elif isinstance(_parsed, DataType):
+ _schema = StructType().add("value", _parsed)
if _schema is None or not isinstance(_schema, StructType):
raise ValueError(
"Some of types cannot be determined after inferring, "
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8e91709b8fc..d3c0fbc0272 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -22,7 +22,7 @@ import json
import pyarrow as pa
-from typing import Optional
+from typing import Any, Dict, Optional
from pyspark.sql.types import (
DataType,
@@ -51,7 +51,6 @@ from pyspark.sql.types import (
)
import pyspark.sql.connect.proto as pb2
-from pyspark.sql.utils import is_remote
JVM_BYTE_MIN: int = -(1 << 7)
@@ -64,6 +63,63 @@ JVM_LONG_MIN: int = -(1 << 63)
JVM_LONG_MAX: int = (1 << 63) - 1
+class UnparsedDataType(DataType):
+ """
+ Unparsed data type.
+
+ The data type string will be parsed later.
+
+ Parameters
+ ----------
+ data_type_string : str
+ The data type string format equals :class:`DataType.simpleString`,
+ except that the top level struct type can omit the ``struct<>``.
+ This also supports a schema in a DDL-formatted string and case-insensitive strings.
+
+ Examples
+ --------
+ >>> from pyspark.sql.connect.types import UnparsedDataType
+
+ >>> UnparsedDataType("int ")
+ UnparsedDataType('int ')
+ >>> UnparsedDataType("INT ")
+ UnparsedDataType('INT ')
+ >>> UnparsedDataType("a: byte, b: decimal( 16 , 8 ) ")
+ UnparsedDataType('a: byte, b: decimal( 16 , 8 ) ')
+ >>> UnparsedDataType("a DOUBLE, b STRING")
+ UnparsedDataType('a DOUBLE, b STRING')
+ >>> UnparsedDataType("a DOUBLE, b CHAR( 50 )")
+ UnparsedDataType('a DOUBLE, b CHAR( 50 )')
+ >>> UnparsedDataType("a DOUBLE, b VARCHAR( 50 )")
+ UnparsedDataType('a DOUBLE, b VARCHAR( 50 )')
+ >>> UnparsedDataType("a: array< short>")
+ UnparsedDataType('a: array< short>')
+ >>> UnparsedDataType(" map<string , string > ")
+ UnparsedDataType(' map<string , string > ')
+ """
+
+ def __init__(self, data_type_string: str):
+ self.data_type_string = data_type_string
+
+ def simpleString(self) -> str:
+ return "unparsed(%s)" % repr(self.data_type_string)
+
+ def __repr__(self) -> str:
+ return "UnparsedDataType(%s)" % repr(self.data_type_string)
+
+ def jsonValue(self) -> Dict[str, Any]:
+ raise AssertionError("Invalid call to jsonValue on unresolved object")
+
+ def needConversion(self) -> bool:
+ raise AssertionError("Invalid call to needConversion on unresolved object")
+
+ def toInternal(self, obj: Any) -> Any:
+ raise AssertionError("Invalid call to toInternal on unresolved object")
+
+ def fromInternal(self, obj: Any) -> Any:
+ raise AssertionError("Invalid call to fromInternal on unresolved object")
+
+
def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret = pb2.DataType()
if isinstance(data_type, NullType):
@@ -125,6 +181,9 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.udt.serialized_python_class = json_value["serializedClass"]
ret.udt.python_class = json_value["pyClass"]
ret.udt.sql_type.CopyFrom(pyspark_types_to_proto_types(data_type.sqlType()))
+ elif isinstance(data_type, UnparsedDataType):
+ data_type_string = data_type.data_type_string
+ ret.unparsed.data_type_string = data_type_string
else:
raise Exception(f"Unsupported data type {data_type}")
return ret
@@ -339,23 +398,3 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
for field in arrow_schema
]
)
-
-
-def parse_data_type(data_type: str) -> DataType:
- # Currently we don't have a way to have a current Spark session in Spark Connect, and
- # pyspark.sql.SparkSession has a centralized logic to control the session creation.
- # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
- # Spark session for Spark Connect in the future.
- from pyspark.sql import SparkSession as PySparkSession
-
- assert is_remote()
- return_type_schema = (
- PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
- )
- with_col_name = " " in data_type.strip()
- if len(return_type_schema.fields) == 1 and not with_col_name:
- # To match pyspark.sql.types._parse_datatype_string
- return_type = return_type_schema.fields[0].dataType
- else:
- return_type = return_type_schema
- return return_type
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index bb7b70e613a..2128c908081 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -26,14 +26,13 @@ import functools
from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
from pyspark.rdd import PythonEvalType
-from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect.expressions import (
ColumnReference,
PythonUDF,
CommonInlineUserDefinedFunction,
)
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.types import parse_data_type
+from pyspark.sql.connect.types import UnparsedDataType
from pyspark.sql.types import DataType, StringType
from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
@@ -99,8 +98,8 @@ class UserDefinedFunction:
)
self.func = func
- self._returnType = (
- parse_data_type(returnType) if isinstance(returnType, str) else returnType
+ self._returnType: DataType = (
+ UnparsedDataType(returnType) if isinstance(returnType, str) else returnType
)
self._name = name or (
func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
@@ -116,13 +115,10 @@ class UserDefinedFunction:
]
arg_exprs = [col._expr for col in arg_cols]
- data_type_str = (
- self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType
- )
py_udf = PythonUDF(
- output_type=data_type_str,
+ output_type=self._returnType,
eval_type=self.evalType,
- command=CloudPickleSerializer().dumps((self.func, self._returnType)),
+ func=self.func,
python_ver="%d.%d" % sys.version_info[:2],
)
return CommonInlineUserDefinedFunction(
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index b38b4c28a25..50f0d36be5d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -25,10 +25,8 @@ if should_test_connect:
sql.udf.UserDefinedFunction = UserDefinedFunction
-from pyspark.errors import AnalysisException
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.sql.types import IntegerType
class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
@@ -91,8 +89,8 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
super().test_worker_original_stdin_closed()
@unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
- def test_udf(self):
- super().test_udf()
+ def test_udf_on_sql_context(self):
+ super().test_udf_on_sql_context()
# TODO(SPARK-42247): implement `UserDefinedFunction.returnType`
@unittest.skip("Fails in Spark Connect, should enable.")
@@ -104,22 +102,13 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
def test_udf_registration_return_type_none(self):
super().test_udf_registration_return_type_none()
- def test_non_existed_udf(self):
- spark = self.spark
- self.assertRaisesRegex(
- AnalysisException,
- "Can not load class non_existed_udf",
- lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
- )
-
- def test_udf_registration_returns_udf(self):
- df = self.spark.range(10)
- add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
-
- self.assertListEqual(
- df.selectExpr("add_three(id) AS plus_three").collect(),
- df.select(add_three("id").alias("plus_three")).collect(),
- )
+ @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
+ def test_non_existed_udf_with_sql_context(self):
+ super().test_non_existed_udf_with_sql_context()
+
+ @unittest.skip("Spark Connect does not support SQLContext but the test depends on it.")
+ def test_udf_registration_returns_udf_on_sql_context(self):
+ super().test_udf_registration_returns_udf_on_sql_context()
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 0f93babbd6c..b766b0c0178 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -74,6 +74,7 @@ class BaseUDFTestsMixin(object):
[row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
self.assertEqual(row[0], 5)
+ def test_udf_on_sql_context(self):
# This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
@@ -369,6 +370,9 @@ class BaseUDFTestsMixin(object):
df.select(add_three("id").alias("plus_three")).collect(),
)
+ def test_udf_registration_returns_udf_on_sql_context(self):
+ df = self.spark.range(10)
+
# This is to check if a 'SQLContext.udf' can call its alias.
sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())
@@ -416,6 +420,7 @@ class BaseUDFTestsMixin(object):
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
)
+ def test_non_existed_udf_with_sql_context(self):
# This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
self.assertRaisesRegex(