You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/13 06:31:35 UTC
[spark] branch master updated: [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 617c3554b27 [SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client
617c3554b27 is described below
commit 617c3554b2737a3cc3f9edc8e2685e94662c5251
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Mar 13 14:31:16 2023 +0800
[SPARK-42756][CONNECT][PYTHON] Helper function to convert proto literal to value in Python Client
### What changes were proposed in this pull request?
Adds a helper function, `LiteralExpression._to_value`, that converts a proto literal back to a Python value in the Python client.
### Why are the changes needed?
needed in .ml
### Does this PR introduce _any_ user-facing change?
No, this is a developer-only change.
### How was this patch tested?
Added a unit test.
Closes #40376 from zhengruifeng/connect_literal_to_value.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/expressions.py | 58 ++++++++++++++++++++++
.../pyspark/sql/tests/connect/test_connect_plan.py | 50 +++++++++++++++++++
2 files changed, 108 insertions(+)
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index dbf260382f7..0e0aa49cda8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -71,6 +71,7 @@ from pyspark.sql.connect.types import (
JVM_LONG_MAX,
UnparsedDataType,
pyspark_types_to_proto_types,
+ proto_schema_to_pyspark_data_type,
)
if TYPE_CHECKING:
@@ -308,6 +309,63 @@ class LiteralExpression(Expression):
def _from_value(cls, value: Any) -> "LiteralExpression":
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
+    @classmethod
+    def _to_value(
+        cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
+    ) -> Any:
+        """Convert a proto literal into the Python value it carries.
+
+        Parameters
+        ----------
+        literal : proto.Expression.Literal
+            The proto literal to convert. The field that is set on it decides
+            which conversion is applied.
+        dataType : DataType, optional
+            The type the caller expects ``literal`` to have. When given, it is
+            checked against the field that is actually set. A null literal is
+            accepted regardless of ``dataType``.
+
+        Returns
+        -------
+        Any
+            ``None`` for a null literal; the raw field value for binary, boolean,
+            byte, short, integer, long, float, double and string literals; a
+            ``decimal.Decimal`` for decimal literals; the result of the matching
+            type's ``fromInternal`` for date, timestamp, timestamp_ntz and
+            day_time_interval literals; a ``list`` of converted elements for
+            array literals.
+
+        Raises
+        ------
+        AssertionError
+            If ``dataType`` is given and does not match the literal, or if an
+            array element does not match the array's declared element type.
+        TypeError
+            If none of the handled literal fields is set.
+        """
+        if literal.HasField("null"):
+            return None
+        elif literal.HasField("binary"):
+            assert dataType is None or isinstance(dataType, BinaryType)
+            return literal.binary
+        elif literal.HasField("boolean"):
+            assert dataType is None or isinstance(dataType, BooleanType)
+            return literal.boolean
+        elif literal.HasField("byte"):
+            assert dataType is None or isinstance(dataType, ByteType)
+            return literal.byte
+        elif literal.HasField("short"):
+            assert dataType is None or isinstance(dataType, ShortType)
+            return literal.short
+        elif literal.HasField("integer"):
+            assert dataType is None or isinstance(dataType, IntegerType)
+            return literal.integer
+        elif literal.HasField("long"):
+            assert dataType is None or isinstance(dataType, LongType)
+            return literal.long
+        elif literal.HasField("float"):
+            assert dataType is None or isinstance(dataType, FloatType)
+            return literal.float
+        elif literal.HasField("double"):
+            assert dataType is None or isinstance(dataType, DoubleType)
+            return literal.double
+        elif literal.HasField("decimal"):
+            assert dataType is None or isinstance(dataType, DecimalType)
+            return decimal.Decimal(literal.decimal.value)
+        elif literal.HasField("string"):
+            assert dataType is None or isinstance(dataType, StringType)
+            return literal.string
+        elif literal.HasField("date"):
+            # Check against DateType, not the DataType base class: every type
+            # is a DataType, so that check could never reject a mismatch.
+            assert dataType is None or isinstance(dataType, DateType)
+            return DateType().fromInternal(literal.date)
+        elif literal.HasField("timestamp"):
+            assert dataType is None or isinstance(dataType, TimestampType)
+            return TimestampType().fromInternal(literal.timestamp)
+        elif literal.HasField("timestamp_ntz"):
+            assert dataType is None or isinstance(dataType, TimestampNTZType)
+            return TimestampNTZType().fromInternal(literal.timestamp_ntz)
+        elif literal.HasField("day_time_interval"):
+            assert dataType is None or isinstance(dataType, DayTimeIntervalType)
+            return DayTimeIntervalType().fromInternal(literal.day_time_interval)
+        elif literal.HasField("array"):
+            # The element type declared inside the proto is authoritative; an
+            # explicitly passed ArrayType must agree with it.
+            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+            if dataType is not None:
+                assert isinstance(dataType, ArrayType)
+                assert elementType == dataType.elementType
+            # Each element is validated against the declared element type.
+            return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+
+        raise TypeError(f"Unsupported Literal Value {literal}")
+
+
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
"""Converts the literal expression to the literal in proto."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index f627136650d..129a25098b1 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -42,6 +42,7 @@ if should_test_connect:
IntegerType,
MapType,
ArrayType,
+ DoubleType,
)
@@ -986,6 +987,55 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
self.assertEqual(len(l4.array.elements[1].array.elements), 2)
self.assertEqual(len(l4.array.elements[2].array.elements), 0)
+    def test_literal_to_any_conversion(self):
+        """Check that ``LiteralExpression._to_value`` inverts ``_from_value``/``to_plan``.
+
+        Each sample value is turned into a proto literal and converted back; the
+        result must equal the original value. The remaining cases check that a
+        mismatching ``dataType``, or an array element that does not match the
+        array's declared element type, raises ``AssertionError``.
+        """
+        for value in [
+            b"binary\0\0asas",
+            True,
+            False,
+            0,
+            12,
+            -1,
+            0.0,
+            1.234567,
+            decimal.Decimal(0.0),
+            decimal.Decimal(1.234567),
+            "sss",
+            datetime.date(2022, 12, 13),
+            datetime.datetime.now(),
+            datetime.timedelta(1, 2, 3),
+            [1, 2, 3, 4, 5, 6],
+            [-1.0, 2.0, 3.0],
+            ["x", "y", "z"],
+            [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]],
+        ]:
+            lit = LiteralExpression._from_value(value)
+            proto_lit = lit.to_plan(None).literal
+            value2 = LiteralExpression._to_value(proto_lit)
+            self.assertEqual(value, value2)
+
+        # A double literal must be rejected when a string is expected.
+        with self.assertRaises(AssertionError):
+            lit = LiteralExpression._from_value(1.234567)
+            proto_lit = lit.to_plan(None).literal
+            LiteralExpression._to_value(proto_lit, StringType())
+
+        # A string literal must be rejected when a double is expected.
+        with self.assertRaises(AssertionError):
+            lit = LiteralExpression._from_value("1.234567")
+            proto_lit = lit.to_plan(None).literal
+            LiteralExpression._to_value(proto_lit, DoubleType())
+
+        with self.assertRaises(AssertionError):
+            # build an array<string> proto literal, but with incorrect elements
+            proto_lit = proto.Expression().literal
+            proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType()))
+            proto_lit.array.elements.append(
+                LiteralExpression("string", StringType()).to_plan(None).literal
+            )
+            proto_lit.array.elements.append(
+                LiteralExpression(1.234, DoubleType()).to_plan(None).literal
+            )
+
+            # Pass the matching array type so the failure comes from the
+            # mismatched double element, not from the outer type check.
+            LiteralExpression._to_value(proto_lit, ArrayType(StringType()))
+
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_plan import * # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org