Posted to commits@spark.apache.org by gu...@apache.org on 2021/11/18 00:03:30 UTC

[spark] branch master updated: [SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e2e1e42  [SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF
e2e1e42 is described below

commit e2e1e42cee3a7f8c3ba7a1fe5f4d5f607792f28e
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Thu Nov 18 09:02:02 2021 +0900

    [SPARK-37279][PYTHON][SQL] Support DayTimeIntervalType in createDataFrame, collect and Python UDF
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for `DayTimeIntervalType` in PySpark's `DataFrame.collect()`, `SparkSession.createDataFrame()` and `functions.udf`.
    This type is mapped to [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects).
    
    The Arrow code path will be implemented separately in SPARK-37277, and Py4J support will be added in SPARK-37281.
    
    ### Why are the changes needed?
    
    - To support `datetime.timedelta` out of the box in PySpark.
    - To seamlessly support the ANSI standard interval types.
    
    Semantically, [`datetime.timedelta`](https://docs.python.org/3/library/datetime.html#timedelta-objects) maps to `DayTimeIntervalType`: Python's `timedelta` does not represent months, years, etc.
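    
    As a rough sketch of that mapping (based on the microsecond `toInternal`/`fromInternal` conversion added in this patch; the helper name `to_micros` is purely illustrative):
    
    ```python
    >>> import datetime
    >>> import math
    >>> def to_micros(td):  # same formula as DayTimeIntervalType.toInternal in this patch
    ...     return math.floor(td.total_seconds()) * 1000000 + td.microseconds
    ...
    >>> to_micros(datetime.timedelta(days=1, microseconds=123))
    86400000123
    >>> datetime.timedelta(microseconds=86400000123)  # the fromInternal direction
    datetime.timedelta(days=1, microseconds=123)
    ```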
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users will be able to use `datetime.timedelta` in PySpark with `DayTimeIntervalType` in `DataFrame.collect()`, `SparkSession.createDataFrame()` and `functions.udf`:
    
    ```python
    >>> import datetime
    >>> df = spark.createDataFrame([(datetime.timedelta(days=1),)])
    >>> df.collect()
    [Row(_1=datetime.timedelta(days=1))]
    ```
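    
    An explicit DDL schema can be used as well (a sketch mirroring the tests added in this patch):
    
    ```python
    >>> df = spark.createDataFrame(
    ...     [(datetime.timedelta(hours=2),)], schema="td interval day to second")
    >>> df.dtypes
    [('td', 'interval day to second')]
    ```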
    
    ```python
    >>> from pyspark.sql.functions import udf
    >>> df.select(udf(lambda x: x, "interval day to second")("_1")).show()
    +--------------------+
    |        <lambda>(_1)|
    +--------------------+
    |INTERVAL '1 00:00...|
    +--------------------+
    ```
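    
    The type itself can also be constructed with explicit start and end fields; a sketch based on the field constants and `simpleString` behaviour exercised by the new tests:
    
    ```python
    >>> from pyspark.sql.types import DayTimeIntervalType
    >>> DayTimeIntervalType().simpleString()
    'interval day to second'
    >>> DayTimeIntervalType(DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND).simpleString()
    'interval hour to second'
    ```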
    
    ### How was this patch tested?
    
    Unit tests were added.
    
    Closes #34614 from HyukjinKwon/SPARK-37277.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/docs/source/reference/pyspark.sql.rst       |  1 +
 python/pyspark/sql/tests/test_types.py             | 78 +++++++++++++++++++++-
 python/pyspark/sql/tests/test_udf.py               | 20 +++++-
 python/pyspark/sql/types.py                        | 76 ++++++++++++++++++++-
 .../sql/execution/python/EvaluatePython.scala      | 13 ++--
 5 files changed, 178 insertions(+), 10 deletions(-)

diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst
index 63a347e..5928b7b 100644
--- a/python/docs/source/reference/pyspark.sql.rst
+++ b/python/docs/source/reference/pyspark.sql.rst
@@ -302,6 +302,7 @@ Data Types
     StructType
     TimestampNTZType
     TimestampType
+    DayTimeIntervalType
 
 
 Observation
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index f009106..135660f 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -35,6 +35,7 @@ from pyspark.sql.types import (
     FloatType,
     DateType,
     TimestampType,
+    DayTimeIntervalType,
     MapType,
     StringType,
     StructType,
@@ -144,6 +145,7 @@ class TypesTests(ReusedSQLTestCase):
             "a",
             datetime.date(1970, 1, 1),
             datetime.datetime(1970, 1, 1, 0, 0),
+            datetime.timedelta(microseconds=123456678),
             1.0,
             array.array("d", [1]),
             [1],
@@ -165,6 +167,7 @@ class TypesTests(ReusedSQLTestCase):
             "string",
             "date",
             "timestamp",
+            "interval day to second",
             "double",
             "array<double>",
             "array<bigint>",
@@ -186,6 +189,7 @@ class TypesTests(ReusedSQLTestCase):
             "a",
             datetime.date(1970, 1, 1),
             datetime.datetime(1970, 1, 1, 0, 0),
+            datetime.timedelta(microseconds=123456678),
             1.0,
             [1.0],
             [1],
@@ -290,7 +294,7 @@ class TypesTests(ReusedSQLTestCase):
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
     def test_apply_schema(self):
-        from datetime import date, datetime
+        from datetime import date, datetime, timedelta
 
         rdd = self.sc.parallelize(
             [
@@ -303,6 +307,7 @@ class TypesTests(ReusedSQLTestCase):
                     1.0,
                     date(2010, 1, 1),
                     datetime(2010, 1, 1, 1, 1, 1),
+                    timedelta(days=1),
                     {"a": 1},
                     (2,),
                     [1, 2, 3],
@@ -320,6 +325,7 @@ class TypesTests(ReusedSQLTestCase):
                 StructField("float1", FloatType(), False),
                 StructField("date1", DateType(), False),
                 StructField("time1", TimestampType(), False),
+                StructField("daytime1", DayTimeIntervalType(), False),
                 StructField("map1", MapType(StringType(), IntegerType(), False), False),
                 StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
                 StructField("list1", ArrayType(ByteType(), False), False),
@@ -337,6 +343,7 @@ class TypesTests(ReusedSQLTestCase):
                 x.float1,
                 x.date1,
                 x.time1,
+                x.daytime1,
                 x.map1["a"],
                 x.struct1.b,
                 x.list1,
@@ -352,6 +359,7 @@ class TypesTests(ReusedSQLTestCase):
             1.0,
             date(2010, 1, 1),
             datetime(2010, 1, 1, 1, 1, 1),
+            timedelta(days=1),
             1,
             2,
             [1, 2, 3],
@@ -929,6 +937,74 @@ class TypesTests(ReusedSQLTestCase):
                 a = array.array(t)
                 self.spark.createDataFrame([Row(myarray=a)]).collect()
 
+    def test_daytime_interval_type_constructor(self):
+        # SPARK-37277: Test constructors in day time interval.
+        self.assertEqual(DayTimeIntervalType().simpleString(), "interval day to second")
+        self.assertEqual(
+            DayTimeIntervalType(DayTimeIntervalType.DAY).simpleString(), "interval day"
+        )
+        self.assertEqual(
+            DayTimeIntervalType(
+                DayTimeIntervalType.HOUR, DayTimeIntervalType.SECOND
+            ).simpleString(),
+            "interval hour to second",
+        )
+
+        with self.assertRaisesRegex(RuntimeError, "interval None to 3 is invalid"):
+            DayTimeIntervalType(endField=DayTimeIntervalType.SECOND)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 123 to 123 is invalid"):
+            DayTimeIntervalType(123)
+
+        with self.assertRaisesRegex(RuntimeError, "interval 0 to 321 is invalid"):
+            DayTimeIntervalType(DayTimeIntervalType.DAY, 321)
+
+    def test_daytime_interval_type(self):
+        # SPARK-37277: Support DayTimeIntervalType in createDataFrame
+        timedeltas = [
+            (datetime.timedelta(microseconds=123),),
+            (
+                datetime.timedelta(
+                    days=1, seconds=23, microseconds=123, milliseconds=4, minutes=5, hours=11
+                ),
+            ),
+            (datetime.timedelta(microseconds=-123),),
+            (datetime.timedelta(days=-1),),
+        ]
+        df = self.spark.createDataFrame(timedeltas, schema="td interval day to second")
+        self.assertEqual(set(r.td for r in df.collect()), set(set(r[0] for r in timedeltas)))
+
+        exprs = [
+            "INTERVAL '1 02:03:04' DAY TO SECOND AS a",
+            "INTERVAL '1 02:03' DAY TO MINUTE AS b",
+            "INTERVAL '1 02' DAY TO HOUR AS c",
+            "INTERVAL '1' DAY AS d",
+            "INTERVAL '26:03:04' HOUR TO SECOND AS e",
+            "INTERVAL '26:03' HOUR TO MINUTE AS f",
+            "INTERVAL '26' HOUR AS g",
+            "INTERVAL '1563:04' MINUTE TO SECOND AS h",
+            "INTERVAL '1563' MINUTE AS i",
+            "INTERVAL '93784' SECOND AS j",
+        ]
+        df = self.spark.range(1).selectExpr(exprs)
+
+        actual = list(df.first())
+        expected = [
+            datetime.timedelta(days=1, hours=2, minutes=3, seconds=4),
+            datetime.timedelta(days=1, hours=2, minutes=3),
+            datetime.timedelta(days=1, hours=2),
+            datetime.timedelta(days=1),
+            datetime.timedelta(hours=26, minutes=3, seconds=4),
+            datetime.timedelta(hours=26, minutes=3),
+            datetime.timedelta(hours=26),
+            datetime.timedelta(minutes=1563, seconds=4),
+            datetime.timedelta(minutes=1563),
+            datetime.timedelta(seconds=93784),
+        ]
+
+        for n, (a, e) in enumerate(zip(actual, expected)):
+            self.assertEqual(a, e, "%s does not match with %s" % (exprs[n], expected[n]))
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 8cb87e7..52d6fa4 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
 
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf
+from pyspark.sql.functions import udf, assert_true, lit
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -36,6 +36,7 @@ from pyspark.sql.types import (
     StructType,
     StructField,
     TimestampNTZType,
+    DayTimeIntervalType,
 )
 from pyspark.sql.utils import AnalysisException
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -607,6 +608,23 @@ class UDFTests(ReusedSQLTestCase):
             self.assertEqual(df.schema[0].dataType.typeName(), "timestamp_ntz")
             self.assertEqual(df.first()[0], datetime.datetime(1970, 1, 1, 0, 0))
 
+    def test_udf_daytime_interval(self):
+        # SPARK-37277: Support DayTimeIntervalType in Python UDF
+        @udf(DayTimeIntervalType(DayTimeIntervalType.DAY, DayTimeIntervalType.SECOND))
+        def noop(x):
+            assert x == datetime.timedelta(microseconds=123)
+            return x
+
+        df = self.spark.createDataFrame(
+            [(datetime.timedelta(microseconds=123),)], schema="td interval day to second"
+        ).select(noop("td").alias("td"))
+
+        df.select(
+            assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.td.cast("string"))
+        ).collect()
+        self.assertEqual(df.schema[0].dataType.simpleString(), "interval day to second")
+        self.assertEqual(df.first()[0], datetime.timedelta(microseconds=123))
+
     def test_nonparam_udf_with_aggregate(self):
         import pyspark.sql.functions as f
 
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9d7b4cf..eda68b8 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -18,6 +18,7 @@
 import sys
 import decimal
 import time
+import math
 import datetime
 import calendar
 import json
@@ -65,6 +66,7 @@ __all__ = [
     "ByteType",
     "IntegerType",
     "LongType",
+    "DayTimeIntervalType",
     "Row",
     "ShortType",
     "ArrayType",
@@ -317,6 +319,65 @@ class LongType(IntegralType):
         return "bigint"
 
 
+class DayTimeIntervalType(AtomicType):
+    """DayTimeIntervalType (datetime.timedelta)."""
+
+    DAY = 0
+    HOUR = 1
+    MINUTE = 2
+    SECOND = 3
+
+    _fields = {
+        DAY: "day",
+        HOUR: "hour",
+        MINUTE: "minute",
+        SECOND: "second",
+    }
+
+    _inverted_fields = dict(zip(_fields.values(), _fields.keys()))
+
+    def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None):
+        if startField is None and endField is None:
+            # Default matched to scala side.
+            startField = DayTimeIntervalType.DAY
+            endField = DayTimeIntervalType.SECOND
+        elif startField is not None and endField is None:
+            endField = startField
+
+        fields = DayTimeIntervalType._fields
+        if startField not in fields.keys() or endField not in fields.keys():
+            raise RuntimeError("interval %s to %s is invalid" % (startField, endField))
+        self.startField = cast(int, startField)
+        self.endField = cast(int, endField)
+
+    def _str_repr(self) -> str:
+        fields = DayTimeIntervalType._fields
+        start_field_name = fields[self.startField]
+        end_field_name = fields[self.endField]
+        if start_field_name == end_field_name:
+            return "interval %s" % start_field_name
+        else:
+            return "interval %s to %s" % (start_field_name, end_field_name)
+
+    simpleString = _str_repr
+
+    jsonValue = _str_repr
+
+    def __repr__(self) -> str:
+        return "%s(%d,%d)" % (type(self).__name__, self.startField, self.endField)
+
+    def needConversion(self) -> bool:
+        return True
+
+    def toInternal(self, dt: datetime.timedelta) -> Optional[int]:
+        if dt is not None:
+            return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds
+
+    def fromInternal(self, micros: int) -> Optional[datetime.timedelta]:
+        if micros is not None:
+            return datetime.timedelta(microseconds=micros)
+
+
 class ShortType(IntegralType):
     """Short data type, i.e. a signed 16-bit integer."""
 
@@ -905,6 +966,7 @@ _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dic
 
 
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
+_INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 
 
 def _parse_datatype_string(s: str) -> DataType:
@@ -1034,11 +1096,17 @@ def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType:
             return _all_atomic_types[json_value]()
         elif json_value == "decimal":
             return DecimalType()
-        elif json_value == "timestamp_ntz":
-            return TimestampNTZType()
         elif _FIXED_DECIMAL.match(json_value):
             m = _FIXED_DECIMAL.match(json_value)
             return DecimalType(int(m.group(1)), int(m.group(2)))  # type: ignore[union-attr]
+        elif _INTERVAL_DAYTIME.match(json_value):
+            m = _INTERVAL_DAYTIME.match(json_value)
+            inverted_fields = DayTimeIntervalType._inverted_fields
+            first_field = inverted_fields.get(m.group(1))  # type: ignore[union-attr]
+            second_field = inverted_fields.get(m.group(3))  # type: ignore[union-attr]
+            if first_field is not None and second_field is None:
+                return DayTimeIntervalType(first_field)
+            return DayTimeIntervalType(first_field, second_field)
         else:
             raise ValueError("Could not parse datatype: %s" % json_value)
     else:
@@ -1063,6 +1131,7 @@ _type_mappings = {
     datetime.date: DateType,
     datetime.datetime: TimestampType,  # can be TimestampNTZType
     datetime.time: TimestampType,  # can be TimestampNTZType
+    datetime.timedelta: DayTimeIntervalType,
     bytes: BinaryType,
 }
 
@@ -1163,6 +1232,8 @@ def _infer_type(
         return DecimalType(38, 18)
     if dataType is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None:
         return TimestampNTZType()
+    if dataType is DayTimeIntervalType:
+        return DayTimeIntervalType()
     elif dataType is not None:
         return dataType()
 
@@ -1409,6 +1480,7 @@ _acceptable_types = {
     DateType: (datetime.date, datetime.datetime),
     TimestampType: (datetime.datetime,),
     TimestampNTZType: (datetime.datetime,),
+    DayTimeIntervalType: (datetime.timedelta,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
     StructType: (tuple, list, dict),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 5d19b93..6664acf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.UTF8String
 object EvaluatePython {
 
   def needConversionInPython(dt: DataType): Boolean = dt match {
-    case DateType | TimestampType | TimestampNTZType => true
+    case DateType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => true
     case _: StructType => true
     case _: UserDefinedType[_] => true
     case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -137,11 +137,12 @@ object EvaluatePython {
       case c: Int => c
     }
 
-    case TimestampType | TimestampNTZType => (obj: Any) => nullSafeConvert(obj) {
-      case c: Long => c
-      // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
-      case c: Int => c.toLong
-    }
+    case TimestampType | TimestampNTZType | _: DayTimeIntervalType => (obj: Any) =>
+      nullSafeConvert(obj) {
+        case c: Long => c
+        // Py4J serializes values between MIN_INT and MAX_INT as Ints, not Longs
+        case c: Int => c.toLong
+      }
 
     case StringType => (obj: Any) => nullSafeConvert(obj) {
       case _ => UTF8String.fromString(obj.toString)
