Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/10 02:19:15 UTC

[spark] branch branch-3.4 updated: [SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new a01f4d6ac8e [SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
a01f4d6ac8e is described below

commit a01f4d6ac8eb228fef79b21eb94235a64cceaa4d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Mar 10 10:18:44 2023 +0800

    [SPARK-42725][CONNECT][PYTHON] Make LiteralExpression support array params
    
    ### What changes were proposed in this pull request?
    Make `LiteralExpression` support arrays.
    
    ### Why are the changes needed?
    MLlib requires literals to carry array params, such as `IntArrayParam` and `DoubleArrayArrayParam`.
    
    Note that this PR doesn't affect the existing `functions.lit` method, which applies an unresolved `CreateArray` expression to support array input.
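    
    A minimal sketch of the new typed-array literal path (mirroring the unit tests added below; `to_plan(None)` suffices for literals, so no live session is assumed):
    
    ```python
    from pyspark.sql.connect.expressions import LiteralExpression
    
    # The element type is inferred from the first element, so a Python list
    # becomes a typed Array literal proto rather than a CreateArray call.
    proto = LiteralExpression._from_value([1, 2, 3]).to_plan(None).literal
    assert proto.array.element_type.HasField("integer")
    assert len(proto.array.elements) == 3
    ```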
    
    ### Does this PR introduce _any_ user-facing change?
    No, dev-only
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #40349 from zhengruifeng/connect_py_ml_lit.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit d6d0fc74d36567c5163878656de787d6fb418604)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../sql/expressions/LiteralProtoConverter.scala    |  2 +-
 .../main/protobuf/spark/connect/expressions.proto  |  4 +-
 .../planner/LiteralValueProtoConverter.scala       |  3 +-
 python/pyspark/sql/connect/expressions.py          | 38 ++++++++++---
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 66 +++++++++++-----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 16 +++---
 .../sql/tests/connect/test_connect_column.py       |  3 -
 .../pyspark/sql/tests/connect/test_connect_plan.py | 42 ++++++++++++++
 8 files changed, 118 insertions(+), 56 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
index b3b9f53e7bb..daddfa9b5af 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/LiteralProtoConverter.scala
@@ -59,7 +59,7 @@ object LiteralProtoConverter {
     def arrayBuilder(array: Array[_]) = {
       val ab = builder.getArrayBuilder
         .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
-      array.foreach(x => ab.addElement(toLiteralProto(x)))
+      array.foreach(x => ab.addElements(toLiteralProto(x)))
       ab
     }
 
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 9e949dab15a..af67f10e05f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -192,8 +192,8 @@ message Expression {
     }
 
     message Array {
-      DataType elementType = 1;
-      repeated Literal element = 2;
+      DataType element_type = 1;
+      repeated Literal elements = 2;
     }
   }
 
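A minimal sketch of the renamed proto fields as seen from the generated Python bindings (constructor keywords as declared in the `.pyi` stub further below; the `DataType(integer=...)` constructor is an assumption based on `spark/connect/types.proto`):

```python
from pyspark.sql.connect.proto import expressions_pb2, types_pb2

# Construct the Array literal message directly with the new snake_case fields.
arr = expressions_pb2.Expression.Literal.Array(
    element_type=types_pb2.DataType(integer=types_pb2.DataType.Integer()),
    elements=[expressions_pb2.Expression.Literal(integer=i) for i in (1, 2, 3)],
)
assert arr.element_type.HasField("integer") and len(arr.elements) == 3
```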
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
index 79c489b9f5b..7a580913867 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralValueProtoConverter.scala
@@ -105,6 +105,7 @@ object LiteralValueProtoConverter {
         expressions.Literal.create(
           toArrayData(lit.getArray),
           ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))
+
       case _ =>
         throw InvalidPlanInput(
           s"Unsupported Literal Type: ${lit.getLiteralTypeCase.getNumber}" +
@@ -143,7 +144,7 @@ object LiteralValueProtoConverter {
     def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
         tag: ClassTag[T]): Array[T] = {
       val builder = mutable.ArrayBuilder.make[T]
-      val elementList = array.getElementList
+      val elementList = array.getElementsList
       builder.sizeHint(elementList.size())
       val iter = elementList.iterator()
       while (iter.hasNext) {
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 5c122f40373..dbf260382f7 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -41,6 +41,7 @@ from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.types import (
     _from_numpy_type,
     DateType,
+    ArrayType,
     NullType,
     BooleanType,
     BinaryType,
@@ -193,6 +194,7 @@ class LiteralExpression(Expression):
                 TimestampType,
                 TimestampNTZType,
                 DayTimeIntervalType,
+                ArrayType,
             ),
         )
 
@@ -247,6 +249,8 @@ class LiteralExpression(Expression):
                 assert isinstance(value, datetime.timedelta)
                 value = DayTimeIntervalType().toInternal(value)
                 assert value is not None
+            elif isinstance(dataType, ArrayType):
+                assert isinstance(value, list)
             else:
                 raise TypeError(f"Unsupported Data Type {dataType}")
 
@@ -280,14 +284,25 @@ class LiteralExpression(Expression):
             return DateType()
         elif isinstance(value, datetime.timedelta):
             return DayTimeIntervalType()
-        else:
-            if isinstance(value, np.generic):
-                dt = _from_numpy_type(value.dtype)
-                if dt is not None:
-                    return dt
-                elif isinstance(value, np.bool_):
-                    return BooleanType()
-            raise TypeError(f"Unsupported Data Type {type(value).__name__}")
+        elif isinstance(value, np.generic):
+            dt = _from_numpy_type(value.dtype)
+            if dt is not None:
+                return dt
+            elif isinstance(value, np.bool_):
+                return BooleanType()
+        elif isinstance(value, list):
+            # follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type'
+            # right now, it's dedicated for pyspark.ml params like array<...>, array<array<...>>
+            if len(value) == 0:
+                raise TypeError("Can not infer Array Type from an empty list")
+            first = value[0]
+            if first is None:
+                raise TypeError(
+                    "Can not infer Array Type from a list with None as the first element"
+                )
+            return ArrayType(LiteralExpression._infer_type(first), True)
+
+        raise TypeError(f"Unsupported Data Type {type(value).__name__}")
 
     @classmethod
     def _from_value(cls, value: Any) -> "LiteralExpression":
@@ -330,6 +345,13 @@ class LiteralExpression(Expression):
             expr.literal.timestamp_ntz = int(self._value)
         elif isinstance(self._dataType, DayTimeIntervalType):
             expr.literal.day_time_interval = int(self._value)
+        elif isinstance(self._dataType, ArrayType):
+            element_type = self._dataType.elementType
+            expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
+            for v in self._value:
+                expr.literal.array.elements.append(
+                    LiteralExpression(v, element_type).to_plan(session).literal
+                )
         else:
             raise ValueError(f"Unsupported Data Type {self._dataType}")
 
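A short sketch of the first-element inference strategy implemented above (behavior as coded in `_infer_type`; no session required):

```python
from pyspark.sql.connect.expressions import LiteralExpression

# Nested lists infer nested ArrayTypes from their first elements.
dt = LiteralExpression._infer_type([[1.0], [2.0, 3.0]])
print(dt)  # -> ArrayType(ArrayType(DoubleType(), True), True)

# Empty lists and a None first element are rejected, as coded above.
for bad in ([], [None]):
    try:
        LiteralExpression._infer_type(bad)
    except TypeError as e:
        print(e)
```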
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 1814f52f539..f736a7927a7 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xac\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
 )
 
 
@@ -323,7 +323,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5137
+    _EXPRESSION._serialized_end = 5141
     _EXPRESSION_WINDOW._serialized_start = 1475
     _EXPRESSION_WINDOW._serialized_end = 2258
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -341,39 +341,39 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_CAST._serialized_start = 2689
     _EXPRESSION_CAST._serialized_end = 2834
     _EXPRESSION_LITERAL._serialized_start = 2837
-    _EXPRESSION_LITERAL._serialized_end = 3907
+    _EXPRESSION_LITERAL._serialized_end = 3911
     _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3545
     _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3662
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3664
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3762
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3764
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 3891
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3909
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4021
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4024
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4228
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4230
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4280
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4282
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4364
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4366
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4452
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4455
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4587
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4590
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4777
-    _EXPRESSION_ALIAS._serialized_start = 4779
-    _EXPRESSION_ALIAS._serialized_end = 4899
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4902
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5060
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
-    _PYTHONUDF._serialized_start = 5507
-    _PYTHONUDF._serialized_end = 5662
-    _SCALARSCALAUDF._serialized_start = 5665
-    _SCALARSCALAUDF._serialized_end = 5849
-    _JAVAUDF._serialized_start = 5852
-    _JAVAUDF._serialized_end = 6001
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3765
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 3895
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3913
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4025
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4028
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4232
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4234
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4284
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4286
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4368
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4370
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4456
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4459
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4591
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4594
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4781
+    _EXPRESSION_ALIAS._serialized_start = 4783
+    _EXPRESSION_ALIAS._serialized_end = 4903
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4906
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5064
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5066
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5128
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5144
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5508
+    _PYTHONUDF._serialized_start = 5511
+    _PYTHONUDF._serialized_end = 5666
+    _SCALARSCALAUDF._serialized_start = 5669
+    _SCALARSCALAUDF._serialized_end = 5853
+    _JAVAUDF._serialized_start = 5856
+    _JAVAUDF._serialized_end = 6005
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 3c8de8abb4e..16f84694d2f 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -443,12 +443,12 @@ class Expression(google.protobuf.message.Message):
         class Array(google.protobuf.message.Message):
             DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-            ELEMENTTYPE_FIELD_NUMBER: builtins.int
-            ELEMENT_FIELD_NUMBER: builtins.int
+            ELEMENT_TYPE_FIELD_NUMBER: builtins.int
+            ELEMENTS_FIELD_NUMBER: builtins.int
             @property
-            def elementType(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
             @property
-            def element(
+            def elements(
                 self,
             ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
                 global___Expression.Literal
@@ -456,16 +456,16 @@ class Expression(google.protobuf.message.Message):
             def __init__(
                 self,
                 *,
-                elementType: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
-                element: collections.abc.Iterable[global___Expression.Literal] | None = ...,
+                element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+                elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
             ) -> None: ...
             def HasField(
-                self, field_name: typing_extensions.Literal["elementType", b"elementType"]
+                self, field_name: typing_extensions.Literal["element_type", b"element_type"]
             ) -> builtins.bool: ...
             def ClearField(
                 self,
                 field_name: typing_extensions.Literal[
-                    "element", b"element", "elementType", b"elementType"
+                    "element_type", b"element_type", "elements", b"elements"
                 ],
             ) -> None: ...
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index b5d8163f4f7..c9c715a3a61 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -23,7 +23,6 @@ from pyspark.sql.types import (
     Row,
     StructField,
     StructType,
-    ArrayType,
     MapType,
     NullType,
     DateType,
@@ -437,7 +436,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
             (0.1, DecimalType()),
             (datetime.date(2022, 12, 13), TimestampType()),
             (datetime.timedelta(1, 2, 3), DateType()),
-            ([1, 2, 3], ArrayType(IntegerType())),
             ({1: 2}, MapType(IntegerType(), IntegerType())),
             (
                 {"a": "xyz", "b": 1},
@@ -474,7 +472,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
         for value, dataType in [
             ("123", NullType()),
             (123, NullType()),
-            (None, ArrayType(IntegerType())),
             (None, MapType(IntegerType(), IntegerType())),
             (None, StructType([StructField("a", StringType())])),
         ]:
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index d5cffa459d7..f627136650d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -18,6 +18,7 @@ import unittest
 import uuid
 import datetime
 import decimal
+import math
 
 from pyspark.testing.connectutils import (
     PlanOnlyTestFixture,
@@ -31,6 +32,7 @@ if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.connect.plan import WriteOperation, Read
     from pyspark.sql.connect.readwriter import DataFrameReader
+    from pyspark.sql.connect.expressions import LiteralExpression
     from pyspark.sql.connect.functions import col, lit, max, min, sum
     from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.types import (
@@ -944,6 +946,46 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
             mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
         )
 
+    def test_literal_expression_with_arrays(self):
+        l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
+        self.assertTrue(l0.array.element_type.HasField("string"))
+        self.assertEqual(len(l0.array.elements), 3)
+        self.assertEqual(l0.array.elements[0].string, "x")
+        self.assertEqual(l0.array.elements[1].string, "y")
+        self.assertEqual(l0.array.elements[2].string, "z")
+
+        l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
+        self.assertTrue(l1.array.element_type.HasField("integer"))
+        self.assertEqual(len(l1.array.elements), 2)
+        self.assertEqual(l1.array.elements[0].integer, 3)
+        self.assertEqual(l1.array.elements[1].integer, -3)
+
+        l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
+        self.assertTrue(l2.array.element_type.HasField("double"))
+        self.assertEqual(len(l2.array.elements), 3)
+        self.assertTrue(math.isnan(l2.array.elements[0].double))
+        self.assertEqual(l2.array.elements[1].double, -3.0)
+        self.assertEqual(l2.array.elements[2].double, 0.0)
+
+        l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
+        self.assertTrue(l3.array.element_type.HasField("array"))
+        self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
+        self.assertEqual(len(l3.array.elements), 2)
+        self.assertEqual(len(l3.array.elements[0].array.elements), 2)
+        self.assertEqual(len(l3.array.elements[1].array.elements), 3)
+
+        l4 = (
+            LiteralExpression._from_value([[float("inf"), 0.4], [0.5, float("nan")], []])
+            .to_plan(None)
+            .literal
+        )
+        self.assertTrue(l4.array.element_type.HasField("array"))
+        self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
+        self.assertEqual(len(l4.array.elements), 3)
+        self.assertEqual(len(l4.array.elements[0].array.elements), 2)
+        self.assertEqual(len(l4.array.elements[1].array.elements), 2)
+        self.assertEqual(len(l4.array.elements[2].array.elements), 0)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_plan import *  # noqa: F401

