You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/10 02:19:15 UTC
[spark] branch branch-3.4 updated: [SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new a01f4d6ac8e [SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
a01f4d6ac8e is described below
commit a01f4d6ac8eb228fef79b21eb94235a64cceaa4d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Mar 10 10:18:44 2023 +0800
[SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
### What changes were proposed in this pull request?
Make LiteralExpression support array
### Why are the changes needed?
MLlib requires literals to carry array params, like `IntArrayParam`, `DoubleArrayArrayParam`.
Note that this PR doesn't affect the existing `functions.lit` method, which applies an unresolved `CreateArray` expression to support array input.
### Does this PR introduce _any_ user-facing change?
No, dev-only
### How was this patch tested?
added UT
Closes #40349 from zhengruifeng/connect_py_ml_lit.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
(cherry picked from commit d6d0fc74d36567c5163878656de787d6fb418604)
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../sql/expressions/LiteralProtoConverter.scala | 2 +-
.../main/protobuf/spark/connect/expressions.proto | 4 +-
.../planner/LiteralValueProtoConverter.scala | 3 +-
python/pyspark/sql/connect/expressions.py | 38 ++++++++++---
.../pyspark/sql/connect/proto/expressions_pb2.py | 66 +++++++++++-----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 16 +++---
.../sql/tests/connect/test_connect_column.py | 3 -
.../pyspark/sql/tests/connect/test_connect_plan.py | 42 ++++++++++++++
8 files changed, 118 insertions(+), 56 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
index b3b9f53e7bb..daddfa9b5af 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
@@ -59,7 +59,7 @@ object LiteralProtoConverter {
def arrayBuilder(array: Array[_]) = {
val ab = builder.getArrayBuilder
.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
- array.foreach(x => ab.addElement(toLiteralProto(x)))
+ array.foreach(x => ab.addElements(toLiteralProto(x)))
ab
}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 9e949dab15a..af67f10e05f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -192,8 +192,8 @@ message Expression {
}
message Array {
- DataType elementType = 1;
- repeated Literal element = 2;
+ DataType element_type = 1;
+ repeated Literal elements = 2;
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 79c489b9f5b..7a580913867 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -105,6 +105,7 @@ object LiteralValueProtoConverter {
expressions.Literal.create(
toArrayData(lit.getArray),
ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))
+
case _ =>
throw InvalidPlanInput(
s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -143,7 +144,7 @@ object LiteralValueProtoConverter {
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
tag: ClassTag[T]): Array[T] = {
val builder = mutable.ArrayBuilder.make[T]
- val elementList = array.getElementList
+ val elementList = array.getElementsList
builder.sizeHint(elementList.size())
val iter = elementList.iterator()
while (iter.hasNext) {
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 5c122f40373..dbf260382f7 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -41,6 +41,7 @@ from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import (
_from_numpy_type,
DateType,
+ ArrayType,
NullType,
BooleanType,
BinaryType,
@@ -193,6 +194,7 @@ class LiteralExpression(Expression):
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
+ ArrayType,
),
)
@@ -247,6 +249,8 @@ class LiteralExpression(Expression):
assert isinstance(value, datetime.timedelta)
value = DayTimeIntervalType().toInternal(value)
assert value is not None
+ elif isinstance(dataType, ArrayType):
+ assert isinstance(value, list)
else:
raise TypeError(f"Unsupported Data Type {dataType}")
@@ -280,14 +284,25 @@ class LiteralExpression(Expression):
return DateType()
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
- else:
- if isinstance(value, np.generic):
- dt = _from_numpy_type(value.dtype)
- if dt is not None:
- return dt
- elif isinstance(value, np.bool_):
- return BooleanType()
- raise TypeError(f"Unsupported Data Type {type(value).__name__}")
+ elif isinstance(value, np.generic):
+ dt = _from_numpy_type(value.dtype)
+ if dt is not None:
+ return dt
+ elif isinstance(value, np.bool_):
+ return BooleanType()
+ elif isinstance(value, list):
+ # follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type'
+ # right now, it's dedicated for pyspark.ml params like array<...>, array<array<...>>
+ if len(value) == 0:
+ raise TypeError("Can not infer Array Type from an empty list")
+ first = value[0]
+ if first is None:
+ raise TypeError(
+ "Can not infer Array Type from an list with None as the first element"
+ )
+ return ArrayType(LiteralExpression._infer_type(first), True)
+
+ raise TypeError(f"Unsupported Data Type {type(value).__name__}")
@classmethod
def _from_value(cls, value: Any) -> "LiteralExpression":
@@ -330,6 +345,13 @@ class LiteralExpression(Expression):
expr.literal.timestamp_ntz = int(self._value)
elif isinstance(self._dataType, DayTimeIntervalType):
expr.literal.day_time_interval = int(self._value)
+ elif isinstance(self._dataType, ArrayType):
+ element_type = self._dataType.elementType
+ expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
+ for v in self._value:
+ expr.literal.array.elements.append(
+ LiteralExpression(v, element_type).to_plan(session).literal
+ )
else:
raise ValueError(f"Unsupported Data Type {self._dataType}")
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 1814f52f539..f736a7927a7 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xac\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
)
@@ -323,7 +323,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 5137
+ _EXPRESSION._serialized_end = 5141
_EXPRESSION_WINDOW._serialized_start = 1475
_EXPRESSION_WINDOW._serialized_end = 2258
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -341,39 +341,39 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_CAST._serialized_start = 2689
_EXPRESSION_CAST._serialized_end = 2834
_EXPRESSION_LITERAL._serialized_start = 2837
- _EXPRESSION_LITERAL._serialized_end = 3907
+ _EXPRESSION_LITERAL._serialized_end = 3911
_EXPRESSION_LITERAL_DECIMAL._serialized_start = 3545
_EXPRESSION_LITERAL_DECIMAL._serialized_end = 3662
_EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3664
_EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3762
- _EXPRESSION_LITERAL_ARRAY._serialized_start = 3764
- _EXPRESSION_LITERAL_ARRAY._serialized_end = 3891
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3909
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4021
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4024
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4228
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4230
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4280
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4282
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4364
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4366
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4452
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4455
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4587
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4590
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4777
- _EXPRESSION_ALIAS._serialized_start = 4779
- _EXPRESSION_ALIAS._serialized_end = 4899
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4902
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5060
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
- _PYTHONUDF._serialized_start = 5507
- _PYTHONUDF._serialized_end = 5662
- _SCALARSCALAUDF._serialized_start = 5665
- _SCALARSCALAUDF._serialized_end = 5849
- _JAVAUDF._serialized_start = 5852
- _JAVAUDF._serialized_end = 6001
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 3765
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 3895
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3913
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4025
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4028
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4232
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4234
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4284
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4286
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4368
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4370
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4456
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4459
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4591
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4594
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4781
+ _EXPRESSION_ALIAS._serialized_start = 4783
+ _EXPRESSION_ALIAS._serialized_end = 4903
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4906
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5064
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5066
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5128
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5144
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5508
+ _PYTHONUDF._serialized_start = 5511
+ _PYTHONUDF._serialized_end = 5666
+ _SCALARSCALAUDF._serialized_start = 5669
+ _SCALARSCALAUDF._serialized_end = 5853
+ _JAVAUDF._serialized_start = 5856
+ _JAVAUDF._serialized_end = 6005
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 3c8de8abb4e..16f84694d2f 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -443,12 +443,12 @@ class Expression(google.protobuf.message.Message):
class Array(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- ELEMENTTYPE_FIELD_NUMBER: builtins.int
- ELEMENT_FIELD_NUMBER: builtins.int
+ ELEMENT_TYPE_FIELD_NUMBER: builtins.int
+ ELEMENTS_FIELD_NUMBER: builtins.int
@property
- def elementType(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+ def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
@property
- def element(
+ def elements(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Expression.Literal
@@ -456,16 +456,16 @@ class Expression(google.protobuf.message.Message):
def __init__(
self,
*,
- elementType: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
- element: collections.abc.Iterable[global___Expression.Literal] | None = ...,
+ element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+ elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["elementType", b"elementType"]
+ self, field_name: typing_extensions.Literal["element_type", b"element_type"]
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "element", b"element", "elementType", b"elementType"
+ "element_type", b"element_type", "elements", b"elements"
],
) -> None: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index b5d8163f4f7..c9c715a3a61 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -23,7 +23,6 @@ from pyspark.sql.types import (
Row,
StructField,
StructType,
- ArrayType,
MapType,
NullType,
DateType,
@@ -437,7 +436,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
(0.1, DecimalType()),
(datetime.date(2022, 12, 13), TimestampType()),
(datetime.timedelta(1, 2, 3), DateType()),
- ([1, 2, 3], ArrayType(IntegerType())),
({1: 2}, MapType(IntegerType(), IntegerType())),
(
{"a": "xyz", "b": 1},
@@ -474,7 +472,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
for value, dataType in [
("123", NullType()),
(123, NullType()),
- (None, ArrayType(IntegerType())),
(None, MapType(IntegerType(), IntegerType())),
(None, StructType([StructField("a", StringType())])),
]:
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index d5cffa459d7..f627136650d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -18,6 +18,7 @@ import unittest
import uuid
import datetime
import decimal
+import math
from pyspark.testing.connectutils import (
PlanOnlyTestFixture,
@@ -31,6 +32,7 @@ if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.readwriter import DataFrameReader
+ from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.connect.functions import col, lit, max, min, sum
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import (
@@ -944,6 +946,46 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
)
+ def test_literal_expression_with_arrays(self):
+ l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
+ self.assertTrue(l0.array.element_type.HasField("string"))
+ self.assertEqual(len(l0.array.elements), 3)
+ self.assertEqual(l0.array.elements[0].string, "x")
+ self.assertEqual(l0.array.elements[1].string, "y")
+ self.assertEqual(l0.array.elements[2].string, "z")
+
+ l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
+ self.assertTrue(l1.array.element_type.HasField("integer"))
+ self.assertEqual(len(l1.array.elements), 2)
+ self.assertEqual(l1.array.elements[0].integer, 3)
+ self.assertEqual(l1.array.elements[1].integer, -3)
+
+ l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
+ self.assertTrue(l2.array.element_type.HasField("double"))
+ self.assertEqual(len(l2.array.elements), 3)
+ self.assertTrue(math.isnan(l2.array.elements[0].double))
+ self.assertEqual(l2.array.elements[1].double, -3.0)
+ self.assertEqual(l2.array.elements[2].double, 0.0)
+
+ l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
+ self.assertTrue(l3.array.element_type.HasField("array"))
+ self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
+ self.assertEqual(len(l3.array.elements), 2)
+ self.assertEqual(len(l3.array.elements[0].array.elements), 2)
+ self.assertEqual(len(l3.array.elements[1].array.elements), 3)
+
+ l4 = (
+ LiteralExpression._from_value([[float("inf"), 0.4], [0.5, float("nan")], []])
+ .to_plan(None)
+ .literal
+ )
+ self.assertTrue(l4.array.element_type.HasField("array"))
+ self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
+ self.assertEqual(len(l4.array.elements), 3)
+ self.assertEqual(len(l4.array.elements[0].array.elements), 2)
+ self.assertEqual(len(l4.array.elements[1].array.elements), 2)
+ self.assertEqual(len(l4.array.elements[2].array.elements), 0)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org