You are viewing a plain-text version of this content. The canonical link to it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@arrow.apache.org by al...@apache.org on 2023/12/08 17:13:35 UTC

(arrow-datafusion) branch main updated: Minor: refactor `data_trunc` to reduce duplicated code (#8430)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 047fb33368 Minor: refactor `data_trunc` to reduce duplicated code (#8430)
047fb33368 is described below

commit 047fb333683b2fbbc3da227480a5a4a8625038aa
Author: Alex Huang <hu...@gmail.com>
AuthorDate: Fri Dec 8 18:13:30 2023 +0100

    Minor: refactor `data_trunc` to reduce duplicated code (#8430)
    
    * refactor data_trunc
    
    * fix cast to timestamp array
    
    * fix cast to timestamp scalar
    
    * fix doc
---
 datafusion/common/src/scalar.rs                    |  15 +++
 .../physical-expr/src/datetime_expressions.rs      | 137 ++++++---------------
 2 files changed, 53 insertions(+), 99 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 7e18c313e0..d730fbf89b 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -46,6 +46,7 @@ use arrow::{
     },
 };
 use arrow_array::cast::as_list_array;
+use arrow_array::types::ArrowTimestampType;
 use arrow_array::{ArrowNativeTypeOp, Scalar};
 
 /// A dynamically typed, nullable single value, (the single-valued counter-part
@@ -774,6 +775,20 @@ impl ScalarValue {
         ScalarValue::IntervalMonthDayNano(Some(val))
     }
 
+    /// Returns a [`ScalarValue`] representing
+    /// `value` and `tz_opt` timezone
+    pub fn new_timestamp<T: ArrowTimestampType>(
+        value: Option<i64>,
+        tz_opt: Option<Arc<str>>,
+    ) -> Self {
+        match T::UNIT {
+            TimeUnit::Second => ScalarValue::TimestampSecond(value, tz_opt),
+            TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(value, tz_opt),
+            TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(value, tz_opt),
+            TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(value, tz_opt),
+        }
+    }
+
     /// Create a zero value in the given type.
     pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
         assert!(datatype.is_primitive());
diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs
index 04cfec29ea..d634b4d019 100644
--- a/datafusion/physical-expr/src/datetime_expressions.rs
+++ b/datafusion/physical-expr/src/datetime_expressions.rs
@@ -36,6 +36,7 @@ use arrow::{
         TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
     },
 };
+use arrow_array::types::ArrowTimestampType;
 use arrow_array::{
     timezone::Tz, TimestampMicrosecondArray, TimestampMillisecondArray,
     TimestampSecondArray,
@@ -43,7 +44,7 @@ use arrow_array::{
 use chrono::prelude::*;
 use chrono::{Duration, Months, NaiveDate};
 use datafusion_common::cast::{
-    as_date32_array, as_date64_array, as_generic_string_array,
+    as_date32_array, as_date64_array, as_generic_string_array, as_primitive_array,
     as_timestamp_microsecond_array, as_timestamp_millisecond_array,
     as_timestamp_nanosecond_array, as_timestamp_second_array,
 };
@@ -335,7 +336,7 @@ fn date_trunc_coarse(granularity: &str, value: i64, tz: Option<Tz>) -> Result<i6
 }
 
 // truncates a single value with the given timeunit to the specified granularity
-fn _date_trunc(
+fn general_date_trunc(
     tu: TimeUnit,
     value: &Option<i64>,
     tz: Option<Tz>,
@@ -403,123 +404,61 @@ pub fn date_trunc(args: &[ColumnarValue]) -> Result<ColumnarValue> {
             return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8");
         };
 
+    fn process_array<T: ArrowTimestampType>(
+        array: &dyn Array,
+        granularity: String,
+        tz_opt: &Option<Arc<str>>,
+    ) -> Result<ColumnarValue> {
+        let parsed_tz = parse_tz(tz_opt)?;
+        let array = as_primitive_array::<T>(array)?;
+        let array = array
+            .iter()
+            .map(|x| general_date_trunc(T::UNIT, &x, parsed_tz, granularity.as_str()))
+            .collect::<Result<PrimitiveArray<T>>>()?
+            .with_timezone_opt(tz_opt.clone());
+        Ok(ColumnarValue::Array(Arc::new(array)))
+    }
+
+    fn process_scalr<T: ArrowTimestampType>(
+        v: &Option<i64>,
+        granularity: String,
+        tz_opt: &Option<Arc<str>>,
+    ) -> Result<ColumnarValue> {
+        let parsed_tz = parse_tz(tz_opt)?;
+        let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?;
+        let value = ScalarValue::new_timestamp::<T>(value, tz_opt.clone());
+        Ok(ColumnarValue::Scalar(value))
+    }
+
     Ok(match array {
         ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => {
-            let parsed_tz = parse_tz(tz_opt)?;
-            let value =
-                _date_trunc(TimeUnit::Nanosecond, v, parsed_tz, granularity.as_str())?;
-            let value = ScalarValue::TimestampNanosecond(value, tz_opt.clone());
-            ColumnarValue::Scalar(value)
+            process_scalr::<TimestampNanosecondType>(v, granularity, tz_opt)?
         }
         ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => {
-            let parsed_tz = parse_tz(tz_opt)?;
-            let value =
-                _date_trunc(TimeUnit::Microsecond, v, parsed_tz, granularity.as_str())?;
-            let value = ScalarValue::TimestampMicrosecond(value, tz_opt.clone());
-            ColumnarValue::Scalar(value)
+            process_scalr::<TimestampMicrosecondType>(v, granularity, tz_opt)?
         }
         ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => {
-            let parsed_tz = parse_tz(tz_opt)?;
-            let value =
-                _date_trunc(TimeUnit::Millisecond, v, parsed_tz, granularity.as_str())?;
-            let value = ScalarValue::TimestampMillisecond(value, tz_opt.clone());
-            ColumnarValue::Scalar(value)
+            process_scalr::<TimestampMillisecondType>(v, granularity, tz_opt)?
         }
         ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => {
-            let parsed_tz = parse_tz(tz_opt)?;
-            let value =
-                _date_trunc(TimeUnit::Second, v, parsed_tz, granularity.as_str())?;
-            let value = ScalarValue::TimestampSecond(value, tz_opt.clone());
-            ColumnarValue::Scalar(value)
+            process_scalr::<TimestampSecondType>(v, granularity, tz_opt)?
         }
         ColumnarValue::Array(array) => {
             let array_type = array.data_type();
             match array_type {
                 DataType::Timestamp(TimeUnit::Second, tz_opt) => {
-                    let parsed_tz = parse_tz(tz_opt)?;
-                    let array = as_timestamp_second_array(array)?;
-                    let array = array
-                        .iter()
-                        .map(|x| {
-                            _date_trunc(
-                                TimeUnit::Second,
-                                &x,
-                                parsed_tz,
-                                granularity.as_str(),
-                            )
-                        })
-                        .collect::<Result<TimestampSecondArray>>()?
-                        .with_timezone_opt(tz_opt.clone());
-                    ColumnarValue::Array(Arc::new(array))
+                    process_array::<TimestampSecondType>(array, granularity, tz_opt)?
                 }
                 DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => {
-                    let parsed_tz = parse_tz(tz_opt)?;
-                    let array = as_timestamp_millisecond_array(array)?;
-                    let array = array
-                        .iter()
-                        .map(|x| {
-                            _date_trunc(
-                                TimeUnit::Millisecond,
-                                &x,
-                                parsed_tz,
-                                granularity.as_str(),
-                            )
-                        })
-                        .collect::<Result<TimestampMillisecondArray>>()?
-                        .with_timezone_opt(tz_opt.clone());
-                    ColumnarValue::Array(Arc::new(array))
+                    process_array::<TimestampMillisecondType>(array, granularity, tz_opt)?
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => {
-                    let parsed_tz = parse_tz(tz_opt)?;
-                    let array = as_timestamp_microsecond_array(array)?;
-                    let array = array
-                        .iter()
-                        .map(|x| {
-                            _date_trunc(
-                                TimeUnit::Microsecond,
-                                &x,
-                                parsed_tz,
-                                granularity.as_str(),
-                            )
-                        })
-                        .collect::<Result<TimestampMicrosecondArray>>()?
-                        .with_timezone_opt(tz_opt.clone());
-                    ColumnarValue::Array(Arc::new(array))
+                    process_array::<TimestampMicrosecondType>(array, granularity, tz_opt)?
                 }
                 DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => {
-                    let parsed_tz = parse_tz(tz_opt)?;
-                    let array = as_timestamp_nanosecond_array(array)?;
-                    let array = array
-                        .iter()
-                        .map(|x| {
-                            _date_trunc(
-                                TimeUnit::Nanosecond,
-                                &x,
-                                parsed_tz,
-                                granularity.as_str(),
-                            )
-                        })
-                        .collect::<Result<TimestampNanosecondArray>>()?
-                        .with_timezone_opt(tz_opt.clone());
-                    ColumnarValue::Array(Arc::new(array))
-                }
-                _ => {
-                    let parsed_tz = None;
-                    let array = as_timestamp_nanosecond_array(array)?;
-                    let array = array
-                        .iter()
-                        .map(|x| {
-                            _date_trunc(
-                                TimeUnit::Nanosecond,
-                                &x,
-                                parsed_tz,
-                                granularity.as_str(),
-                            )
-                        })
-                        .collect::<Result<TimestampNanosecondArray>>()?;
-
-                    ColumnarValue::Array(Arc::new(array))
+                    process_array::<TimestampNanosecondType>(array, granularity, tz_opt)?
                 }
+                _ => process_array::<TimestampNanosecondType>(array, granularity, &None)?,
             }
         }
         _ => {