You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/04 14:57:24 UTC
[spark] branch master updated: [SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f31f64d15bcb [SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect
f31f64d15bcb is described below
commit f31f64d15bcb06987c6d2301b107cfe6b24f825f
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Fri Nov 4 23:57:12 2022 +0900
[SPARK-40533][CONNECT][PYTHON] Support most built-in literal types for Python in Spark Connect
### What changes were proposed in this pull request?
This PR implements the client-side serialization of most Python literals into Spark Connect literals.
### Why are the changes needed?
Expanding the Python client support.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests (UT).
Closes #38462 from grundprinzip/SPARK-40533.
Lead-authored-by: Martin Grund <ma...@databricks.com>
Co-authored-by: Martin Grund <gr...@gmail.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/_typing.py | 2 +
python/pyspark/sql/connect/column.py | 56 ++++++++++++++-
python/pyspark/sql/connect/functions.py | 5 +-
.../connect/test_connect_column_expressions.py | 83 +++++++++++++++++++++-
4 files changed, 138 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 5cd14111bada..4e69b2e4aa5e 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -15,5 +15,7 @@
# limitations under the License.
#
from typing import Union
+from datetime import date, time, datetime
PrimitiveType = Union[str, int, bool, float]
+LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 126c45d6b4a8..42466fa16992 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import uuid
from typing import cast, get_args, TYPE_CHECKING, Optional, Callable, Any
+import decimal
+import datetime
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect._typing import PrimitiveType
@@ -87,7 +89,7 @@ class LiteralExpression(Expression):
The Python types are converted best effort into the relevant proto types. On the Spark Connect
server side, the proto types are converted to the Catalyst equivalents."""
- def __init__(self, value: PrimitiveType) -> None:
+ def __init__(self, value: Any) -> None:
super().__init__()
self._value = value
@@ -99,11 +101,59 @@ class LiteralExpression(Expression):
value_type = type(self._value)
exp = proto.Expression()
if value_type is int:
- exp.literal.i32 = cast(int, self._value)
+ exp.literal.i64 = cast(int, self._value)
+ elif value_type is bool:
+ exp.literal.boolean = cast(bool, self._value)
elif value_type is str:
exp.literal.string = cast(str, self._value)
elif value_type is float:
exp.literal.fp64 = cast(float, self._value)
+ elif value_type is decimal.Decimal:
+ d_v = cast(decimal.Decimal, self._value)
+ v_tuple = d_v.as_tuple()
+ exp.literal.decimal.scale = abs(v_tuple.exponent)
+ exp.literal.decimal.precision = len(v_tuple.digits) - abs(v_tuple.exponent)
+ # TODO: the unscaled value (presumably two's-complement encoded) is not serialized yet.
+ raise ValueError("Python Decimal not supported.")
+ elif value_type is bytes:
+ exp.literal.binary = self._value
+ elif value_type is datetime.datetime:
+ # Microseconds since epoch.
+ dt = cast(datetime.datetime, self._value)
+ v = dt - datetime.datetime(1970, 1, 1, 0, 0, 0, 0)
+ exp.literal.timestamp = int(v / datetime.timedelta(microseconds=1))
+ elif value_type is datetime.time:
+ # Nanoseconds of the day.
+ tv = cast(datetime.time, self._value)
+ offset = (tv.second + tv.minute * 60 + tv.hour * 3600) * 1000 + tv.microsecond
+ exp.literal.time = int(offset * 1000)
+ elif value_type is datetime.date:
+ # Days since epoch.
+ days_since_epoch = (cast(datetime.date, self._value) - datetime.date(1970, 1, 1)).days
+ exp.literal.date = days_since_epoch
+ elif value_type is uuid.UUID:
+ raise ValueError("Python UUID type not supported.")
+ elif value_type is list:
+ lv = cast(list, self._value)
+ for k in lv:
+ if type(k) is LiteralExpression:
+ exp.literal.list.values.append(k.to_plan(session).literal)
+ else:
+ exp.literal.list.values.append(LiteralExpression(k).to_plan(session).literal)
+ elif value_type is dict:
+ mv = cast(dict, self._value)
+ for k in mv:
+ kv = proto.Expression.Literal.Map.KeyValue()
+ if type(k) is LiteralExpression:
+ kv.key.CopyFrom(k.to_plan(session).literal)
+ else:
+ kv.key.CopyFrom(LiteralExpression(k).to_plan(session).literal)
+
+ if type(mv[k]) is LiteralExpression:
+ kv.value.CopyFrom(mv[k].to_plan(session).literal)
+ else:
+ kv.value.CopyFrom(LiteralExpression(mv[k]).to_plan(session).literal)
+ exp.literal.map.key_values.append(kv)
else:
raise ValueError(f"Could not convert literal for type {type(self._value)}")
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 4fe57d922837..880096da4598 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -15,7 +15,8 @@
# limitations under the License.
#
from pyspark.sql.connect.column import ColumnRef, LiteralExpression
-from pyspark.sql.connect.column import PrimitiveType
+
+from typing import Any
# TODO(SPARK-40538) Add support for the missing PySpark functions.
@@ -24,5 +25,5 @@ def col(x: str) -> ColumnRef:
return ColumnRef(x)
-def lit(x: PrimitiveType) -> LiteralExpression:
+def lit(x: Any) -> LiteralExpression:
return LiteralExpression(x)
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 790a987e8809..8773fe4aceba 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -14,9 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import uuid
from typing import cast
import unittest
+import decimal
+import datetime
+
from pyspark.testing.connectutils import PlanOnlyTestFixture
from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message
@@ -49,6 +52,32 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
self.assertEqual(cp1, cp2)
self.assertEqual(cp2, cp3)
+ def test_binary_literal(self):
+ val = b"binary\0\0asas"
+ bin_lit = fun.lit(val)
+ bin_lit_p = bin_lit.to_plan(None)
+ self.assertEqual(bin_lit_p.literal.binary, val)
+
+ def test_map_literal(self):
+ val = {"this": "is", 12: [12, 32, 43]}
+ map_lit = fun.lit(val)
+ map_lit_p = map_lit.to_plan(None)
+ self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+ self.assertEqual("this", map_lit_p.literal.map.key_values[0].key.string)
+ self.assertEqual(12, map_lit_p.literal.map.key_values[1].key.i64)
+
+ val = {"this": fun.lit("is"), 12: [12, 32, 43]}
+ map_lit = fun.lit(val)
+ map_lit_p = map_lit.to_plan(None)
+ self.assertEqual(2, len(map_lit_p.literal.map.key_values))
+ self.assertEqual("is", map_lit_p.literal.map.key_values[0].value.string)
+
+ def test_uuid_literal(self):
+ val = uuid.uuid4()
+ lit = fun.lit(val)
+ with self.assertRaises(ValueError):
+ lit.to_plan(None)
+
def test_column_literals(self):
df = c.DataFrame.withPlan(p.Read("table"))
lit_df = df.select(fun.lit(10))
@@ -56,7 +85,55 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
self.assertIsNotNone(fun.lit(10).to_plan(None))
plan = fun.lit(10).to_plan(None)
- self.assertIs(plan.literal.i32, 10)
+ self.assertIs(plan.literal.i64, 10)
+
+ def test_numeric_literal_types(self):
+ int_lit = fun.lit(10)
+ float_lit = fun.lit(10.1)
+ decimal_lit = fun.lit(decimal.Decimal(99))
+
+ # Decimal is not supported yet.
+ with self.assertRaises(ValueError):
+ self.assertIsNotNone(decimal_lit.to_plan(None))
+
+ self.assertIsNotNone(int_lit.to_plan(None))
+ self.assertIsNotNone(float_lit.to_plan(None))
+
+ def test_datetime_literal_types(self):
+ """Test the different timestamp, date, and time types."""
+ datetime_lit = fun.lit(datetime.datetime.now())
+
+ p = datetime_lit.to_plan(None)
+ self.assertIsNotNone(datetime_lit.to_plan(None))
+ self.assertGreater(p.literal.timestamp, 10000000000000)
+
+ date_lit = fun.lit(datetime.date.today())
+ time_lit = fun.lit(datetime.time())
+
+ self.assertIsNotNone(date_lit.to_plan(None))
+ self.assertIsNotNone(time_lit.to_plan(None))
+
+ def test_list_to_literal(self):
+ """Test conversion of lists to literals"""
+ empty_list = []
+ single_type = [1, 2, 3, 4]
+ multi_type = ["ooo", 1, "asas", 2.3]
+
+ empty_list_lit = fun.lit(empty_list)
+ single_type_lit = fun.lit(single_type)
+ multi_type_lit = fun.lit(multi_type)
+
+ p = empty_list_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ p = single_type_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ p = multi_type_lit.to_plan(None)
+ self.assertIsNotNone(p)
+
+ lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
+ self.assertIsNotNone(lit_list_plan)
def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
@@ -76,7 +153,7 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, ProtoExpression)
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
- self.assertEqual(lit_fun.literal.i32, 10)
+ self.assertEqual(lit_fun.literal.i64, 10)
mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, ProtoExpression)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org