You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/27 20:20:37 UTC

[arrow-datafusion] branch main updated: Make 'date_trunc' returns the same type as its input (#6654)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ac8f6908ba Make 'date_trunc' returns the same type as its input (#6654)
ac8f6908ba is described below

commit ac8f6908ba31b869e0eb9d14617879bd0330f6d0
Author: Alex Huang <hu...@gmail.com>
AuthorDate: Tue Jun 27 22:20:33 2023 +0200

    Make 'date_trunc' returns the same type as its input (#6654)
    
    * fix: 'datatrunc' return inconsistent type
    
    * fix: test error
    
    * change return value
    
    * fix error
    
    * remove comment
    
    * update date_trunc
    
    * update date_trunc
    
    * remove unwrap
    
    * truncate timestamp to second in date_trunc_single
---
 datafusion/core/tests/sql/timestamp.rs             |   6 +-
 .../tests/sqllogictests/test_files/timestamps.slt  | 104 ++++++++++++--
 datafusion/expr/src/built_in_function.rs           |  29 ++--
 .../physical-expr/src/datetime_expressions.rs      | 159 ++++++++++++++++-----
 4 files changed, 239 insertions(+), 59 deletions(-)

diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs
index bb4d54a61c..df922844bb 100644
--- a/datafusion/core/tests/sql/timestamp.rs
+++ b/datafusion/core/tests/sql/timestamp.rs
@@ -712,7 +712,7 @@ async fn test_arrow_typeof() -> Result<()> {
         "+--------------------------------------------------------------------------+",
         "| arrow_typeof(date_trunc(Utf8(\"minute\"),to_timestamp_seconds(Int64(61)))) |",
         "+--------------------------------------------------------------------------+",
-        "| Timestamp(Nanosecond, None)                                              |",
+        "| Timestamp(Second, None)                                                  |",
         "+--------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
@@ -723,7 +723,7 @@ async fn test_arrow_typeof() -> Result<()> {
         "+-------------------------------------------------------------------------+",
         "| arrow_typeof(date_trunc(Utf8(\"second\"),to_timestamp_millis(Int64(61)))) |",
         "+-------------------------------------------------------------------------+",
-        "| Timestamp(Nanosecond, None)                                             |",
+        "| Timestamp(Millisecond, None)                                            |",
         "+-------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
@@ -734,7 +734,7 @@ async fn test_arrow_typeof() -> Result<()> {
         "+------------------------------------------------------------------------------+",
         "| arrow_typeof(date_trunc(Utf8(\"millisecond\"),to_timestamp_micros(Int64(61)))) |",
         "+------------------------------------------------------------------------------+",
-        "| Timestamp(Nanosecond, None)                                                  |",
+        "| Timestamp(Microsecond, None)                                                 |",
         "+------------------------------------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
index 5eae354401..3ba7c38f16 100644
--- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt
@@ -29,9 +29,9 @@
 
 statement ok
 create table ts_data(ts bigint, value int) as values
-  (1599572549190855000, 1),
-  (1599568949190855000, 2),
-  (1599565349190855000, 3);
+  (1599572549190855123, 1),
+  (1599568949190855123, 2),
+  (1599565349190855123, 3);
 
 statement ok
 create table ts_data_nanos as select arrow_cast(ts, 'Timestamp(Nanosecond, None)') as ts, value from ts_data;
@@ -270,9 +270,9 @@ SELECT COUNT(*) FROM ts_data_secs where ts > from_unixtime(1599566400)
 query P rowsort
 SELECT DISTINCT ts FROM ts_data_nanos;
 ----
-2020-09-08T11:42:29.190855
-2020-09-08T12:42:29.190855
-2020-09-08T13:42:29.190855
+2020-09-08T11:42:29.190855123
+2020-09-08T12:42:29.190855123
+2020-09-08T13:42:29.190855123
 
 
 query I
@@ -1010,6 +1010,96 @@ ts_data_secs 2020-09-08T00:00:00
 ts_data_secs 2020-09-08T00:00:00
 ts_data_secs 2020-09-08T00:00:00
 
+# Test date_trunc on different granularities
+query TP rowsort
+SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_nanos
+  UNION ALL
+SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_nanos
+  UNION ALL
+SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_nanos
+  UNION ALL
+SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_nanos
+----
+microsecond 2020-09-08T11:42:29.190855
+microsecond 2020-09-08T12:42:29.190855
+microsecond 2020-09-08T13:42:29.190855
+millisecond 2020-09-08T11:42:29.190
+millisecond 2020-09-08T12:42:29.190
+millisecond 2020-09-08T13:42:29.190
+minute 2020-09-08T11:42:00
+minute 2020-09-08T12:42:00
+minute 2020-09-08T13:42:00
+second 2020-09-08T11:42:29
+second 2020-09-08T12:42:29
+second 2020-09-08T13:42:29
+
+query TP rowsort
+SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_micros
+  UNION ALL
+SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_micros
+  UNION ALL
+SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_micros
+  UNION ALL
+SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_micros
+----
+microsecond 2020-09-08T11:42:29.190855
+microsecond 2020-09-08T12:42:29.190855
+microsecond 2020-09-08T13:42:29.190855
+millisecond 2020-09-08T11:42:29.190
+millisecond 2020-09-08T12:42:29.190
+millisecond 2020-09-08T13:42:29.190
+minute 2020-09-08T11:42:00
+minute 2020-09-08T12:42:00
+minute 2020-09-08T13:42:00
+second 2020-09-08T11:42:29
+second 2020-09-08T12:42:29
+second 2020-09-08T13:42:29
+
+query TP rowsort
+SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_millis
+  UNION ALL
+SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_millis
+  UNION ALL
+SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_millis
+  UNION ALL
+SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_millis
+----
+microsecond 2020-09-08T11:42:29.190
+microsecond 2020-09-08T12:42:29.190
+microsecond 2020-09-08T13:42:29.190
+millisecond 2020-09-08T11:42:29.190
+millisecond 2020-09-08T12:42:29.190
+millisecond 2020-09-08T13:42:29.190
+minute 2020-09-08T11:42:00
+minute 2020-09-08T12:42:00
+minute 2020-09-08T13:42:00
+second 2020-09-08T11:42:29
+second 2020-09-08T12:42:29
+second 2020-09-08T13:42:29
+
+query TP rowsort
+SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_secs
+  UNION ALL
+SELECT 'microsecond', DATE_TRUNC('microsecond', ts) FROM ts_data_secs
+  UNION ALL
+SELECT 'second', DATE_TRUNC('second', ts) FROM ts_data_secs
+  UNION ALL
+SELECT 'minute', DATE_TRUNC('minute', ts) FROM ts_data_secs
+----
+microsecond 2020-09-08T11:42:29
+microsecond 2020-09-08T12:42:29
+microsecond 2020-09-08T13:42:29
+millisecond 2020-09-08T11:42:29
+millisecond 2020-09-08T12:42:29
+millisecond 2020-09-08T13:42:29
+minute 2020-09-08T11:42:00
+minute 2020-09-08T12:42:00
+minute 2020-09-08T13:42:00
+second 2020-09-08T11:42:29
+second 2020-09-08T12:42:29
+second 2020-09-08T13:42:29
+
+
 # test date trunc on different timestamp scalar types and ensure they are consistent
 query P rowsort
 SELECT DATE_TRUNC('second', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp(Second, None)')) as ts
@@ -1026,8 +1116,6 @@ SELECT DATE_TRUNC('day', arrow_cast(TIMESTAMP '2023-08-03 14:38:50Z', 'Timestamp
 2023-08-03T14:38:50
 
 
-
-
 # Demonstrate that strings are automatically coerced to timestamps (don't use TIMESTAMP)
 
 query P
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 677743f5b3..ed1d9147d7 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -552,17 +552,19 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Concat => Ok(Utf8),
             BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8),
             BuiltinScalarFunction::DatePart => Ok(Float64),
-            // DateTrunc always makes nanosecond timestamps
-            BuiltinScalarFunction::DateTrunc => Ok(Timestamp(Nanosecond, None)),
-            BuiltinScalarFunction::DateBin => match input_expr_types[1] {
-                Timestamp(Nanosecond, _) | Utf8 => Ok(Timestamp(Nanosecond, None)),
-                Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
-                Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
-                Timestamp(Second, _) => Ok(Timestamp(Second, None)),
-                _ => Err(DataFusionError::Internal(format!(
+            BuiltinScalarFunction::DateBin | BuiltinScalarFunction::DateTrunc => {
+                match input_expr_types[1] {
+                    Timestamp(Nanosecond, _) | Utf8 | Null => {
+                        Ok(Timestamp(Nanosecond, None))
+                    }
+                    Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
+                    Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
+                    Timestamp(Second, _) => Ok(Timestamp(Second, None)),
+                    _ => Err(DataFusionError::Internal(format!(
                     "The {self} function can only accept timestamp as the second arg."
                 ))),
-            },
+                }
+            }
             BuiltinScalarFunction::InitCap => {
                 utf8_to_str_type(&input_expr_types[0], "initcap")
             }
@@ -889,8 +891,13 @@ impl BuiltinScalarFunction {
                 ],
                 self.volatility(),
             ),
-            BuiltinScalarFunction::DateTrunc => Signature::exact(
-                vec![Utf8, Timestamp(Nanosecond, None)],
+            BuiltinScalarFunction::DateTrunc => Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
+                    Exact(vec![Utf8, Timestamp(Microsecond, None)]),
+                    Exact(vec![Utf8, Timestamp(Millisecond, None)]),
+                    Exact(vec![Utf8, Timestamp(Second, None)]),
+                ],
                 self.volatility(),
             ),
             BuiltinScalarFunction::DateBin => {
diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs
index 443514063f..3a36e8a489 100644
--- a/datafusion/physical-expr/src/datetime_expressions.rs
+++ b/datafusion/physical-expr/src/datetime_expressions.rs
@@ -215,13 +215,17 @@ fn quarter_month(date: &NaiveDateTime) -> u32 {
 }
 
 fn date_trunc_single(granularity: &str, value: i64) -> Result<i64> {
+    if granularity == "millisecond" || granularity == "microsecond" {
+        return Ok(value);
+    }
+
     let value = timestamp_ns_to_datetime(value)
         .ok_or_else(|| {
             DataFusionError::Execution(format!("Timestamp {value} out of range"))
         })?
         .with_nanosecond(0);
     let value = match granularity {
-        "second" | "millisecond" | "microsecond" => value,
+        "second" => value,
         "minute" => value.and_then(|d| d.with_second(0)),
         "hour" => value
             .and_then(|d| d.with_second(0))
@@ -262,6 +266,55 @@ fn date_trunc_single(granularity: &str, value: i64) -> Result<i64> {
     Ok(value.unwrap().timestamp_nanos())
 }
 
+fn _date_trunc(
+    tu: TimeUnit,
+    value: &Option<i64>,
+    granularity: &str,
+    f: impl Fn(Option<i64>) -> Result<Option<i64>>,
+) -> Result<Option<i64>, DataFusionError> {
+    let scale = match tu {
+        TimeUnit::Second => 1_000_000_000,
+        TimeUnit::Millisecond => 1_000_000,
+        TimeUnit::Microsecond => 1_000,
+        TimeUnit::Nanosecond => 1,
+    };
+
+    let Some(value) = value else {
+        return Ok(None);
+    };
+
+    // convert to nanoseconds
+    let Some(nano) = (f)(Some(value * scale))? else {
+        return Ok(None);
+    };
+
+    let result = match tu {
+        TimeUnit::Second => match granularity {
+            "minute" => Some(nano / 1_000_000_000 / 60 * 60),
+            _ => Some(nano / 1_000_000_000),
+        },
+        TimeUnit::Millisecond => match granularity {
+            "minute" => Some(nano / 1_000_000 / 1_000 / 60 * 1_000 * 60),
+            "second" => Some(nano / 1_000_000 / 1_000 * 1_000),
+            _ => Some(nano / 1_000_000),
+        },
+        TimeUnit::Microsecond => match granularity {
+            "minute" => Some(nano / 1_000 / 1_000_000 / 60 * 60 * 1_000_000),
+            "second" => Some(nano / 1_000 / 1_000_000 * 1_000_000),
+            "millisecond" => Some(nano / 1_000 / 1_000 * 1_000),
+            _ => Some(nano / 1_000),
+        },
+        _ => match granularity {
+            "minute" => Some(nano / 1_000_000_000 / 60 * 1_000_000_000 * 60),
+            "second" => Some(nano / 1_000_000_000 * 1_000_000_000),
+            "millisecond" => Some(nano / 1_000_000 * 1_000_000),
+            "microsecond" => Some(nano / 1_000 * 1_000),
+            _ => Some(nano),
+        },
+    };
+    Ok(result)
+}
+
 /// date_trunc SQL function
 pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     let (granularity, array) = (&args[0], &args[1]);
@@ -282,49 +335,81 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 
     Ok(match array {
         ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
-            let nano = (f)(*v)?;
-
-            match granularity.as_str() {
-                "minute" => {
-                    // trunc to minute
-                    let second = ScalarValue::TimestampNanosecond(
-                        nano.map(|nano| nano / 1_000_000_000 * 1_000_000_000),
-                        tz_opt.clone(),
-                    );
-                    ColumnarValue::Scalar(second)
+            let value = _date_trunc(TimeUnit::Nanosecond, v, granularity.as_str(), f)?;
+            let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone());
+            ColumnarValue::Scalar(value)
+        }
+        ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => {
+            let value = _date_trunc(TimeUnit::Microsecond, v, granularity.as_str(), f)?;
+            let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone());
+            ColumnarValue::Scalar(value)
+        }
+        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => {
+            let value = _date_trunc(TimeUnit::Millisecond, v, granularity.as_str(), f)?;
+            let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone());
+            ColumnarValue::Scalar(value)
+        }
+        ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => {
+            let value = _date_trunc(TimeUnit::Second, v, granularity.as_str(), f)?;
+            let value = ScalarValue::TimestampSecond(value, tz_opt.clone());
+            ColumnarValue::Scalar(value)
+        }
+        ColumnarValue::Array(array) => {
+            let array_type = array.data_type();
+            match array_type {
+                DataType::Timestamp(TimeUnit::Second, _) => {
+                    let array = as_timestamp_second_array(array)?;
+                    let array = array
+                        .iter()
+                        .map(|x| {
+                            _date_trunc(TimeUnit::Second, &x, granularity.as_str(), f)
+                        })
+                        .collect::<Result<TimestampSecondArray>>()?;
+                    ColumnarValue::Array(Arc::new(array))
                 }
-                "second" => {
-                    // trunc to second
-                    let mill = ScalarValue::TimestampNanosecond(
-                        nano.map(|nano| nano / 1_000_000 * 1_000_000),
-                        tz_opt.clone(),
-                    );
-                    ColumnarValue::Scalar(mill)
+                DataType::Timestamp(TimeUnit::Millisecond, _) => {
+                    let array = as_timestamp_millisecond_array(array)?;
+                    let array = array
+                        .iter()
+                        .map(|x| {
+                            _date_trunc(
+                                TimeUnit::Millisecond,
+                                &x,
+                                granularity.as_str(),
+                                f,
+                            )
+                        })
+                        .collect::<Result<TimestampMillisecondArray>>()?;
+                    ColumnarValue::Array(Arc::new(array))
                 }
-                "millisecond" => {
-                    // trunc to microsecond
-                    let micro = ScalarValue::TimestampNanosecond(
-                        nano.map(|nano| nano / 1_000 * 1_000),
-                        tz_opt.clone(),
-                    );
-                    ColumnarValue::Scalar(micro)
+                DataType::Timestamp(TimeUnit::Microsecond, _) => {
+                    let array = as_timestamp_microsecond_array(array)?;
+                    let array = array
+                        .iter()
+                        .map(|x| {
+                            _date_trunc(
+                                TimeUnit::Microsecond,
+                                &x,
+                                granularity.as_str(),
+                                f,
+                            )
+                        })
+                        .collect::<Result<TimestampMicrosecondArray>>()?;
+                    ColumnarValue::Array(Arc::new(array))
                 }
                 _ => {
-                    // trunc to nanosecond
-                    let nano = ScalarValue::TimestampNanosecond(nano, tz_opt.clone());
-                    ColumnarValue::Scalar(nano)
+                    let array = as_timestamp_nanosecond_array(array)?;
+                    let array = array
+                        .iter()
+                        .map(|x| {
+                            _date_trunc(TimeUnit::Nanosecond, &x, granularity.as_str(), f)
+                        })
+                        .collect::<Result<TimestampNanosecondArray>>()?;
+
+                    ColumnarValue::Array(Arc::new(array))
                 }
             }
         }
-        ColumnarValue::Array(array) => {
-            let array = as_timestamp_nanosecond_array(array)?;
-            let array = array
-                .iter()
-                .map(f)
-                .collect::<Result<TimestampNanosecondArray>>()?;
-
-            ColumnarValue::Array(Arc::new(array))
-        }
         _ => {
             return Err(DataFusionError::Execution(
                 "second argument of `date_trunc` must be nanosecond timestamp scalar or array".to_string(),