Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/04 14:57:24 UTC

[spark] branch master updated: [SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f31f64d15bcb [SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect
f31f64d15bcb is described below

commit f31f64d15bcb06987c6d2301b107cfe6b24f825f
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Fri Nov 4 23:57:12 2022 +0900

    [SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR implements the client-side serialization of most Python literals into Spark Connect literals.
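    
    A minimal usage sketch of the new conversions (a hypothetical session against a checkout that includes this change; it mirrors the tests added below):
    
    ```python
    import datetime
    import pyspark.sql.connect.functions as fun
    
    plan = fun.lit(10).to_plan(None)
    assert plan.literal.i64 == 10  # integers now map to the 64-bit field
    
    plan = fun.lit("hello").to_plan(None)
    assert plan.literal.string == "hello"
    
    # Dates are encoded as days since the Unix epoch.
    plan = fun.lit(datetime.date(1970, 1, 2)).to_plan(None)
    assert plan.literal.date == 1
    ```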
    
    ### Why are the changes needed?
    This expands the set of Python literal types the Python client can serialize and send to the server.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests added in test_connect_column_expressions.py.
    
    Closes #38462 from grundprinzip/SPARK-40533.
    
    Lead-authored-by: Martin Grund <ma...@databricks.com>
    Co-authored-by: Martin Grund <gr...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/_typing.py              |  2 +
 python/pyspark/sql/connect/column.py               | 56 ++++++++++++++-
 python/pyspark/sql/connect/functions.py            |  5 +-
 .../connect/test_connect_column_expressions.py     | 83 +++++++++++++++++++++-
 4 files changed, 138 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 5cd14111bada..4e69b2e4aa5e 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -15,5 +15,7 @@
 # limitations under the License.
 #
 from typing import Union
+from datetime import date, time, datetime
 
 PrimitiveType = Union[str, int, bool, float]
+LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 126c45d6b4a8..42466fa16992 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any
 
+import decimal
+import datetime
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect._typing import PrimitiveType
@@ -87,7 +89,7 @@ class LiteralExpression(Expression):
     The Python types are converted best effort into the relevant proto types. On the Spark Connect
     server side, the proto types are converted to the Catalyst equivalents."""
 
-    def __init__(self, value: PrimitiveType) -> None:
+    def __init__(self, value: Any) -> None:
         super().__init__()
         self._value = value
 
@@ -99,11 +101,59 @@ class LiteralExpression(Expression):
         value_type = type(self._value)
         exp = proto.Expression()
         if value_type is int:
-            exp.literal.i32 = cast(int, self._value)
+            exp.literal.i64 = cast(int, self._value)
+        elif value_type is bool:
+            exp.literal.boolean = cast(bool, self._value)
         elif value_type is str:
             exp.literal.string = cast(str, self._value)
         elif value_type is float:
             exp.literal.fp64 = cast(float, self._value)
+        elif value_type is decimal.Decimal:
+            d_v = cast(decimal.Decimal, self._value)
+            v_tuple = d_v.as_tuple()
+            exp.literal.decimal.scale = abs(v_tuple.exponent)
+            exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
+            # NOTE: scale/precision above are derived best effort, and the
+            # unscaled value still needs a two's-complement encoding, so
+            # Decimal literals are rejected for now.
+            raise ValueError("Python Decimal not supported.")
+        elif value_type is bytes:
+            exp.literal.binary = self._value
+        elif value_type is datetime.datetime:
+            # Microseconds since epoch.
+            dt = cast(datetime.datetime, self._value)
+            v = dt - datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
+            exp.literal.timestamp = int(v / datetime.timedelta(microseconds=1))
+        elif value_type is datetime.time:
+            # Nanoseconds of the day.
+            tv = cast(datetime.time, self._value)
+            # Seconds of the day converted to microseconds, plus the sub-second part.
+            offset = (tv.second + tv.minute * 60 + tv.hour * 3600) * 1000 * 1000 + tv.microsecond
+            exp.literal.time = int(offset * 1000)
+        elif value_type is datetime.date:
+            # Days since epoch.
+            days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
+            exp.literal.date = days_since_epoch
+        elif value_type is uuid.UUID:
+            raise ValueError("Python UUID type not supported.")
+        elif value_type is list:
+            lv = cast(list, self._value)
+            for k in lv:
+                if type(k) is LiteralExpression:
+                    exp.literal.list.values.append(k.to_plan(session).literal)
+                else:
+                    exp.literal.list.values.append(LiteralExpression(k).to_plan(session).literal)
+        elif value_type is dict:
+            mv = cast(dict, self._value)
+            for k in mv:
+                kv = proto.Expression.Literal.Map.KeyValue()
+                if type(k) is LiteralExpression:
+                    kv.key.CopyFrom(k.to_plan(session).literal)
+                else:
+                    kv.key.CopyFrom(LiteralExpression(k).to_plan(session).literal)
+
+                if type(mv[k]) is LiteralExpression:
+                    kv.value.CopyFrom(mv[k].to_plan(session).literal)
+                else:
+                    kv.value.CopyFrom(LiteralExpression(mv[k]).to_plan(session).literal)
+                exp.literal.map.key_values.append(kv)
         else:
             raise ValueError(f"Could not convert literal for type {type(self._value)}")
 
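The list and map branches above recurse: each element, key, and value is wrapped in a LiteralExpression unless it already is one, and its encoded Literal message is appended to the parent. A short sketch, assuming a checkout with this change (it mirrors test_map_literal below):

    import pyspark.sql.connect.functions as fun

    plan = fun.lit({"this": fun.lit("is"), 12: [12, 32, 43]}).to_plan(None)
    # Keys and values become nested Expression.Literal messages.
    assert plan.literal.map.key_values[0].key.string == "this"
    assert plan.literal.map.key_values[0].value.string == "is"  # pre-built literal passes through
    assert plan.literal.map.key_values[1].key.i64 == 12
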
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 4fe57d922837..880096da4598 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.column import ColumnRef, LiteralExpression
-from pyspark.sql.connect.column import PrimitiveType
+
+from typing import Any
 
 # TODO(SPARK-40538) Add support for the missing PySpark functions.
 
@@ -24,5 +25,5 @@ def col(x: str) -> ColumnRef:
     return ColumnRef(x)
 
 
-def lit(x: PrimitiveType) -> LiteralExpression:
+def lit(x: Any) -> LiteralExpression:
     return LiteralExpression(x)
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 790a987e8809..8773fe4aceba 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -14,9 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 from typing import cast
 import unittest
+import decimal
+import datetime
+
 from pyspark.testing.connectutils import PlanOnlyTestFixture
 from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
 
@@ -49,6 +52,32 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         self.assertEqual(cp1, cp2)
         self.assertEqual(cp2, cp3)
 
+    def test_binary_literal(self):
+        val = b"binary\0\0asas"
+        bin_lit = fun.lit(val)
+        bin_lit_p = bin_lit.to_plan(None)
+        self.assertEqual(bin_lit_p.literal.binary, val)
+
+    def test_map_literal(self):
+        val = {"this": "is", 12: [12, 32, 43]}
+        map_lit = fun.lit(val)
+        map_lit_p = map_lit.to_plan(None)
+        self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+        self.assertEqual("this", map_lit_p.literal.map.key_values[0].key.string)
+        self.assertEqual(12, map_lit_p.literal.map.key_values[1].key.i64)
+
+        val = {"this": fun.lit("is"), 12: [12, 32, 43]}
+        map_lit = fun.lit(val)
+        map_lit_p = map_lit.to_plan(None)
+        self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+        self.assertEqual("is", map_lit_p.literal.map.key_values[0].value.string)
+
+    def test_uuid_literal(self):
+        val = uuid.uuid4()
+        lit = fun.lit(val)
+        with self.assertRaises(ValueError):
+            lit.to_plan(None)
+
     def test_column_literals(self):
         df = c.DataFrame.withPlan(p.Read("table"))
         lit_df = df.select(fun.lit(10))
@@ -56,7 +85,55 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
 
         self.assertIsNotNone(fun.lit(10).to_plan(None))
         plan = fun.lit(10).to_plan(None)
-        self.assertIs(plan.literal.i32, 10)
+        self.assertIs(plan.literal.i64, 10)
+
+    def test_numeric_literal_types(self):
+        int_lit = fun.lit(10)
+        float_lit = fun.lit(10.1)
+        decimal_lit = fun.lit(decimal.Decimal(99))
+
+        # Decimal is not supported yet.
+        with self.assertRaises(ValueError):
+            self.assertIsNotNone(decimal_lit.to_plan(None))
+
+        self.assertIsNotNone(int_lit.to_plan(None))
+        self.assertIsNotNone(float_lit.to_plan(None))
+
+    def test_datetime_literal_types(self):
+        """Test the different timestamp, date, and time types."""
+        datetime_lit = fun.lit(datetime.datetime.now())
+
+        p = datetime_lit.to_plan(None)
+        self.assertIsNotNone(datetime_lit.to_plan(None))
+        self.assertGreater(p.literal.timestamp, 10000000000000)
+
+        date_lit = fun.lit(datetime.date.today())
+        time_lit = fun.lit(datetime.time())
+
+        self.assertIsNotNone(date_lit.to_plan(None))
+        self.assertIsNotNone(time_lit.to_plan(None))
+
+    def test_list_to_literal(self):
+        """Test conversion of lists to literals"""
+        empty_list = []
+        single_type = [1, 2, 3, 4]
+        multi_type = ["ooo", 1, "asas", 2.3]
+
+        empty_list_lit = fun.lit(empty_list)
+        single_type_lit = fun.lit(single_type)
+        multi_type_lit = fun.lit(multi_type)
+
+        p = empty_list_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = single_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        p = multi_type_lit.to_plan(None)
+        self.assertIsNotNone(p)
+
+        lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
+        self.assertIsNotNone(lit_list_plan)
 
     def test_column_expressions(self):
         """Test a more complex combination of expressions and their translation into
@@ -76,7 +153,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         lit_fun = expr_plan.unresolved_function.arguments[1]
         self.assertIsInstance(lit_fun, ProtoExpression)
         self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
-        self.assertEqual(lit_fun.literal.i32, 10)
+        self.assertEqual(lit_fun.literal.i64, 10)
 
         mod_fun = expr_plan.unresolved_function.arguments[0]
         self.assertIsInstance(mod_fun, ProtoExpression)

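For reference, the timestamp and date encodings in column.py reduce to plain datetime arithmetic. A standalone illustration, runnable without Spark (expected values worked out by hand):

    import datetime

    # Timestamps: microseconds since 1970-01-01T00:00:00.
    dt = datetime.datetime(2022, 11, 4, 23, 57, 12)
    epoch = datetime.datetime(1970, 1, 1)
    assert int((dt - epoch) / datetime.timedelta(microseconds=1)) == 1667606232000000

    # Dates: days since 1970-01-01.
    assert (datetime.date(2022, 11, 4) - datetime.date(1970, 1, 1)).days == 19300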
