Posted to commits@spark.apache.org by gu...@apache.org on 2021/11/18 00:03:30 UTC
[spark] branch master updated: [SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e2e1e42 [SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF
e2e1e42 is described below
commit e2e1e42cee3a7f8c3ba7a1fe5f4d5f607792f28e
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Thu Nov 18 09:02:02 2021 +0900
[SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF
### What changes were proposed in this pull request?
This PR adds `DayTimeIntervalType` support to PySpark's `DataFrame.collect()`, `SparkSession.createDataFrame()`, and `functions.udf`.
This type is mapped to [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects).
The Arrow code path will be implemented separately in SPARK-37277, and Py4J support will be added in SPARK-37281.
### Why are the changes needed?
- To support `datetime.timedelta` out of the box in PySpark.
- To seamlessly support ANSI-standard interval types.
Semantically, [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects) maps to `DayTimeIntervalType`: Python's `timedelta` does not support months or years, so it cannot represent a year-month interval.
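Internally the type stores the interval as a number of microseconds. A minimal sketch of the round trip, mirroring the `toInternal`/`fromInternal` methods this patch adds to `python/pyspark/sql/types.py`:
```python
>>> import datetime
>>> import math
>>> td = datetime.timedelta(microseconds=-123)
>>> # timedelta normalizes to (days, seconds, 0 <= microseconds < 10**6),
>>> # so flooring total_seconds() and adding microseconds back is sign-safe
>>> micros = math.floor(td.total_seconds()) * 1000000 + td.microseconds
>>> micros
-123
>>> datetime.timedelta(microseconds=micros) == td
True
```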
### Does this PR introduce _any_ user-facing change?
Yes, users will be able to use `datetime.timedelta` in PySpark with `DayTimeIntervalType` at `DataFrame.collect()`, `SparkSession.createDataFrame()` and `functions.udf`:
```python
>>> import datetime
>>> df = spark.createDataFrame([(datetime.timedelta(days=1),)])
>>> df.collect()
[Row(_1=datetime.timedelta(days=1))]
```
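The type can also be named in a DDL schema string, as the new tests below exercise:
```python
>>> df = spark.createDataFrame(
...     [(datetime.timedelta(microseconds=123),)], schema="td interval day to second"
... )
>>> df.schema[0].dataType.simpleString()
'interval day to second'
```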
```python
>>> from pyspark.sql.functions import udf
>>> df.select(udf(lambda x: x, "interval day to second")("_1")).show()
+--------------------+
| <lambda>(_1)|
+--------------------+
|INTERVAL '1 00:00...|
+--------------------+
```
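`DayTimeIntervalType` also accepts explicit start and end fields (`DAY`, `HOUR`, `MINUTE`, `SECOND`), defaulting to day-to-second, as the new constructor tests show:
```python
>>> from pyspark.sql.types import DayTimeIntervalType
>>> DayTimeIntervalType().simpleString()
'interval day to second'
>>> DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND).simpleString()
'interval hour to second'
```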
### How was this patch tested?
Unit tests were added.
Closes #34614 from HyukjinKwon/SPARK-37277.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/docs/source/reference/pyspark.sql.rst | 1 +
python/pyspark/sql/tests/test_types.py | 78 +++++++++++++++++++++-
python/pyspark/sql/tests/test_udf.py | 20 +++++-
python/pyspark/sql/types.py | 76 ++++++++++++++++++++-
.../sql/execution/python/EvaluatePython.scala | 13 ++--
5 files changed, 178 insertions(+), 10 deletions(-)
diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 63a347e..5928b7b 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -302,6 +302,7 @@ Data Types
StructType
TimestampNTZType
TimestampType
+ DayTimeIntervalType
Observation
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index f009106..135660f 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
FloatType,
DateType,
TimestampType,
+ DayTimeIntervalType,
MapType,
StringType,
StructType,
@@ -144,6 +145,7 @@ class TypesTests(ReusedSQLTestCase):
"a",
datetime.date(1970, 1, 1),
datetime.datetime(1970, 1, 1, 0, 0),
+ datetime.timedelta(microseconds=123456678),
1.0,
array.array("d", [1]),
[1],
@@ -165,6 +167,7 @@ class TypesTests(ReusedSQLTestCase):
"string",
"date",
"timestamp",
+ "interval day to second",
"double",
"array<double>",
"array<bigint>",
@@ -186,6 +189,7 @@ class TypesTests(ReusedSQLTestCase):
"a",
datetime.date(1970, 1, 1),
datetime.datetime(1970, 1, 1, 0, 0),
+ datetime.timedelta(microseconds=123456678),
1.0,
[1.0],
[1],
@@ -290,7 +294,7 @@ class TypesTests(ReusedSQLTestCase):
self.assertEqual(df.first(), Row(key=1, value="1"))
def test_apply_schema(self):
- from datetime import date, datetime
+ from datetime import date, datetime, timedelta
rdd = self.sc.parallelize(
[
@@ -303,6 +307,7 @@ class TypesTests(ReusedSQLTestCase):
1.0,
date(2010, 1, 1),
datetime(2010, 1, 1, 1, 1, 1),
+ timedelta(days=1),
{"a": 1},
(2,),
[1, 2, 3],
@@ -320,6 +325,7 @@ class TypesTests(ReusedSQLTestCase):
StructField("float1", FloatType(), False),
StructField("date1", DateType(), False),
StructField("time1", TimestampType(), False),
+ StructField("daytime1", DayTimeIntervalType(), False),
StructField("map1", MapType(StringType(), IntegerType(), False), False),
StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
StructField("list1", ArrayType(ByteType(), False), False),
@@ -337,6 +343,7 @@ class TypesTests(ReusedSQLTestCase):
x.float1,
x.date1,
x.time1,
+ x.daytime1,
x.map1["a"],
x.struct1.b,
x.list1,
@@ -352,6 +359,7 @@ class TypesTests(ReusedSQLTestCase):
1.0,
date(2010, 1, 1),
datetime(2010, 1, 1, 1, 1, 1),
+ timedelta(days=1),
1,
2,
[1, 2, 3],
@@ -929,6 +937,74 @@ class TypesTests(ReusedSQLTestCase):
a = array.array(t)
self.spark.createDataFrame([Row(myarray=a)]).collect()
+ def test_daytime_interval_type_constructor(self):
+ # SPARK-37279: Test constructors of the day-time interval type.
+ self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")
+ self.assertEqual(
+ DayTimeIntervalType(DayTimeIntervalType.DAY).simpleString(), "interval day"
+ )
+ self.assertEqual(
+ DayTimeIntervalType(
+ DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND
+ ).simpleString(),
+ "interval hour to second",
+ )
+
+ with self.assertRaisesRegex(RuntimeError, "interval None to 3 is invalid"):
+ DayTimeIntervalType(endField=DayTimeIntervalType.SECOND)
+
+ with self.assertRaisesRegex(RuntimeError, "interval 123 to 123 is invalid"):
+ DayTimeIntervalType(123)
+
+ with self.assertRaisesRegex(RuntimeError, "interval 0 to 321 is invalid"):
+ DayTimeIntervalType(DayTimeIntervalType.DAY, 321)
+
+ def test_daytime_interval_type(self):
+ # SPARK-37279: Support DayTimeIntervalType in createDataFrame
+ timedeltas = [
+ (datetime.timedelta(microseconds=123),),
+ (
+ datetime.timedelta(
+ days=1, seconds=23, microseconds=123, milliseconds=4, minutes=5, hours=11
+ ),
+ ),
+ (datetime.timedelta(microseconds=-123),),
+ (datetime.timedelta(days=-1),),
+ ]
+ df = self.spark.createDataFrame(timedeltas, schema="td interval day to second")
+ self.assertEqual(set(r.td for r in df.collect()), set(r[0] for r in timedeltas))
+
+ exprs = [
+ "INTERVAL '1 02:03:04' DAY TO SECOND AS a",
+ "INTERVAL '1 02:03' DAY TO MINUTE AS b",
+ "INTERVAL '1 02' DAY TO HOUR AS c",
+ "INTERVAL '1' DAY AS d",
+ "INTERVAL '26:03:04' HOUR TO SECOND AS e",
+ "INTERVAL '26:03' HOUR TO MINUTE AS f",
+ "INTERVAL '26' HOUR AS g",
+ "INTERVAL '1563:04' MINUTE TO SECOND AS h",
+ "INTERVAL '1563' MINUTE AS i",
+ "INTERVAL '93784' SECOND AS j",
+ ]
+ df = self.spark.range(1).selectExpr(exprs)
+
+ actual = list(df.first())
+ expected = [
+ datetime.timedelta(days=1, hours=2, minutes=3, seconds=4),
+ datetime.timedelta(days=1, hours=2, minutes=3),
+ datetime.timedelta(days=1, hours=2),
+ datetime.timedelta(days=1),
+ datetime.timedelta(hours=26, minutes=3, seconds=4),
+ datetime.timedelta(hours=26, minutes=3),
+ datetime.timedelta(hours=26),
+ datetime.timedelta(minutes=1563, seconds=4),
+ datetime.timedelta(minutes=1563),
+ datetime.timedelta(seconds=93784),
+ ]
+
+ for n, (a, e) in enumerate(zip(actual, expected)):
+ self.assertEqual(a, e, "%s does not match %s" % (exprs[n], expected[n]))
+
class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 8cb87e7..52d6fa4 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf
+from pyspark.sql.functions import udf, assert_true, lit
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.types import (
StringType,
@@ -36,6 +36,7 @@ from pyspark.sql.types import (
StructType,
StructField,
TimestampNTZType,
+ DayTimeIntervalType,
)
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -607,6 +608,23 @@ class UDFTests(ReusedSQLTestCase):
self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
+ def test_udf_daytime_interval(self):
+ # SPARK-37279: Support DayTimeIntervalType in Python UDF
+ @udf(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND))
+ def noop(x):
+ assert x == datetime.timedelta(microseconds=123)
+ return x
+
+ df = self.spark.createDataFrame(
+ [(datetime.timedelta(microseconds=123),)], schema="td interval day to second"
+ ).select(noop("td").alias("td"))
+
+ df.select(
+ assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.td.cast("string"))
+ ).collect()
+ self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
+ self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
+
def test_nonparam_udf_with_aggregate(self):
import pyspark.sql.functions as f
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9d7b4cf..eda68b8 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -18,6 +18,7 @@
import sys
import decimal
import time
+import math
import datetime
import calendar
import json
@@ -65,6 +66,7 @@ __all__ = [
"ByteType",
"IntegerType",
"LongType",
+ "DayTimeIntervalType",
"Row",
"ShortType",
"ArrayType",
@@ -317,6 +319,65 @@ class LongType(IntegralType):
return "bigint"
+class DayTimeIntervalType(AtomicType):
+ """DayTimeIntervalType (datetime.timedelta)."""
+
+ DAY = 0
+ HOUR = 1
+ MINUTE = 2
+ SECOND = 3
+
+ _fields = {
+ DAY: "day",
+ HOUR: "hour",
+ MINUTE: "minute",
+ SECOND: "second",
+ }
+
+ _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
+
+ def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+ if startField is None and endField is None:
+ # Default matched to scala side.
+ startField = DayTimeIntervalType.DAY
+ endField = DayTimeIntervalType.SECOND
+ elif startField is not None and endField is None:
+ endField = startField
+
+ fields = DayTimeIntervalType._fields
+ if startField not in fields.keys() or endField not in fields.keys():
+ raise RuntimeError("interval %s to %s is invalid" % (startField, endField))
+ self.startField = cast(int, startField)
+ self.endField = cast(int, endField)
+
+ def _str_repr(self) -> str:
+ fields = DayTimeIntervalType._fields
+ start_field_name = fields[self.startField]
+ end_field_name = fields[self.endField]
+ if start_field_name == end_field_name:
+ return "interval %s" % start_field_name
+ else:
+ return "interval %s to %s" % (start_field_name, end_field_name)
+
+ simpleString = _str_repr
+
+ jsonValue = _str_repr
+
+ def __repr__(self) -> str:
+ return "%s(%d,%d)" % (type(self).__name__, self.startField, self.endField)
+
+ def needConversion(self) -> bool:
+ return True
+
+ def toInternal(self, dt: datetime.timedelta) -> Optional[int]:
+ if dt is not None:
+ return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds
+
+ def fromInternal(self, micros: int) -> Optional[datetime.timedelta]:
+ if micros is not None:
+ return datetime.timedelta(microseconds=micros)
+
+
class ShortType(IntegralType):
"""Short data type, i.e. a signed 16-bit integer."""
@@ -905,6 +966,7 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dic
_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
+_INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
def _parse_datatype_string(s: str) -> DataType:
@@ -1034,11 +1096,17 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
return _all_atomic_types[json_value]()
elif json_value == "decimal":
return DecimalType()
- elif json_value == "timestamp_ntz":
- return TimestampNTZType()
elif _FIXED_DECIMAL.match(json_value):
m = _FIXED_DECIMAL.match(json_value)
return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr]
+ elif _INTERVAL_DAYTIME.match(json_value):
+ m = _INTERVAL_DAYTIME.match(json_value)
+ inverted_fields = DayTimeIntervalType._inverted_fields
+ first_field = inverted_fields.get(m.group(1)) # type: ignore[union-attr]
+ second_field = inverted_fields.get(m.group(3)) # type: ignore[union-attr]
+ if first_field is not None and second_field is None:
+ return DayTimeIntervalType(first_field)
+ return DayTimeIntervalType(first_field, second_field)
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
@@ -1063,6 +1131,7 @@ _type_mappings = {
datetime.date: DateType,
datetime.datetime: TimestampType, # can be TimestampNTZType
datetime.time: TimestampType, # can be TimestampNTZType
+ datetime.timedelta: DayTimeIntervalType,
bytes: BinaryType,
}
@@ -1163,6 +1232,8 @@ def _infer_type(
return DecimalType(38, 18)
if dataType is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None:
return TimestampNTZType()
+ if dataType is DayTimeIntervalType:
+ return DayTimeIntervalType()
elif dataType is not None:
return dataType()
@@ -1409,6 +1480,7 @@ _acceptable_types = {
DateType: (datetime.date, datetime.datetime),
TimestampType: (datetime.datetime,),
TimestampNTZType: (datetime.datetime,),
+ DayTimeIntervalType: (datetime.timedelta,),
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list, dict),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 5d19b93..6664acf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String
object EvaluatePython {
def needConversionInPython(dt: DataType): Boolean = dt match {
- case DateType | TimestampType | TimestampNTZType => true
+ case DateType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => true
case _: StructType => true
case _: UserDefinedType[_] => true
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -137,11 +137,12 @@ object EvaluatePython {
case c: Int => c
}
- case TimestampType | TimestampNTZType => (obj: Any) => nullSafeConvert(obj) {
- case c: Long => c
- // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
- case c: Int => c.toLong
- }
+ case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
+ nullSafeConvert(obj) {
+ case c: Long => c
+ // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+ case c: Int => c.toLong
+ }
case StringType => (obj: Any) => nullSafeConvert(obj) {
case _ => UTF8String.fromString(obj.toString)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org