You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/09/08 06:40:44 UTC

[arrow-datafusion] branch main updated: Minor: Add `ScalarValue::data_type()` for consistency with other APIs (#7492)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 93c209f1b5 Minor: Add `ScalarValue::data_type()` for consistency with other APIs (#7492)
93c209f1b5 is described below

commit 93c209f1b5d0a17b2aa5c6743cbd3cb189a406c8
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Sep 8 02:40:38 2023 -0400

    Minor: Add `ScalarValue::data_type()` for consistency with other APIs (#7492)
    
    * Minor: Add `ScalarValue::data_type()` for consistency
    
    * Use new API in a few places
---
 datafusion/common/src/scalar.rs                              | 11 +++++++++--
 datafusion/expr/src/logical_plan/plan.rs                     |  8 ++++----
 datafusion/optimizer/src/unwrap_cast_in_comparison.rs        | 12 ++++++------
 .../physical-expr/src/aggregate/approx_percentile_cont.rs    |  4 ++--
 datafusion/sql/src/expr/value.rs                             |  4 ++--
 5 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 1f017ead59..28452c6a31 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1072,8 +1072,8 @@ impl ScalarValue {
         })
     }
 
-    /// Getter for the `DataType` of the value
-    pub fn get_datatype(&self) -> DataType {
+    /// return the [`DataType`] of this `ScalarValue`
+    pub fn data_type(&self) -> DataType {
         match self {
             ScalarValue::Boolean(_) => DataType::Boolean,
             ScalarValue::UInt8(_) => DataType::UInt8,
@@ -1149,6 +1149,13 @@ impl ScalarValue {
         }
     }
 
+    /// Getter for the `DataType` of the value.
+    ///
+    /// Suggest using  [`Self::data_type`] as a more standard API
+    pub fn get_datatype(&self) -> DataType {
+        self.data_type()
+    }
+
     /// Calculate arithmetic negation for a scalar value
     pub fn arithmetic_negate(&self) -> Result<Self> {
         match self {
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index c58ec92174..4e196a7b96 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -939,11 +939,11 @@ impl LogicalPlan {
                 // Verify if the types of the params matches the types of the values
                 let iter = prepare_lp.data_types.iter().zip(param_values.iter());
                 for (i, (param_type, value)) in iter.enumerate() {
-                    if *param_type != value.get_datatype() {
+                    if *param_type != value.data_type() {
                         return plan_err!(
                             "Expected parameter of type {:?}, got {:?} at index {}",
                             param_type,
-                            value.get_datatype(),
+                            value.data_type(),
                             i
                         );
                     }
@@ -1183,11 +1183,11 @@ impl LogicalPlan {
                         ))
                     })?;
                     // check if the data type of the value matches the data type of the placeholder
-                    if Some(value.get_datatype()) != *data_type {
+                    if Some(value.data_type()) != *data_type {
                         return internal_err!(
                             "Placeholder value type mismatch: expected {:?}, got {:?}",
                             data_type,
-                            value.get_datatype()
+                            value.data_type()
                         );
                     }
                     // Replace the placeholder with the value
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 963a3dc06f..2e12a283ea 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -298,7 +298,7 @@ fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
     target_type: &DataType,
 ) -> Result<Option<ScalarValue>> {
-    let lit_data_type = lit_value.get_datatype();
+    let lit_data_type = lit_value.data_type();
     // the rule just support the signed numeric data type now
     if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
         return Ok(None);
@@ -817,7 +817,7 @@ mod tests {
             for s2 in &scalars {
                 let expected_value = ExpectedCast::Value(s2.clone());
 
-                expect_cast(s1.clone(), s2.get_datatype(), expected_value);
+                expect_cast(s1.clone(), s2.data_type(), expected_value);
             }
         }
     }
@@ -842,7 +842,7 @@ mod tests {
             for s2 in &scalars {
                 let expected_value = ExpectedCast::Value(s2.clone());
 
-                expect_cast(s1.clone(), s2.get_datatype(), expected_value);
+                expect_cast(s1.clone(), s2.data_type(), expected_value);
             }
         }
 
@@ -976,10 +976,10 @@ mod tests {
             assert_eq!(lit_tz_none, lit_tz_utc);
 
             // e.g. DataType::Timestamp(_, None)
-            let dt_tz_none = lit_tz_none.get_datatype();
+            let dt_tz_none = lit_tz_none.data_type();
 
             // e.g. DataType::Timestamp(_, Some(utc))
-            let dt_tz_utc = lit_tz_utc.get_datatype();
+            let dt_tz_utc = lit_tz_utc.data_type();
 
             // None <--> None
             expect_cast(
@@ -1102,7 +1102,7 @@ mod tests {
                 if let (
                     DataType::Timestamp(left_unit, left_tz),
                     DataType::Timestamp(right_unit, right_tz),
-                ) = (actual_value.get_datatype(), expected_value.get_datatype())
+                ) = (actual_value.data_type(), expected_value.data_type())
                 {
                     assert_eq!(left_unit, right_unit);
                     assert_eq!(left_tz, right_tz);
diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 184cada1dc..aa4749f64a 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -147,7 +147,7 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
         ScalarValue::Float64(Some(q)) => *q,
         got => return not_impl_err!(
             "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
-            got.get_datatype()
+            got.data_type()
         )
     };
 
@@ -182,7 +182,7 @@ fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
         ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
         got => return not_impl_err!(
             "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
-            got.get_datatype()
+            got.data_type()
         )
     };
     Ok(max_size)
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 158054ce6c..3d5a39a13a 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -151,14 +151,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
 
         let data_types: HashSet<DataType> =
-            values.iter().map(|e| e.get_datatype()).collect();
+            values.iter().map(|e| e.data_type()).collect();
 
         if data_types.is_empty() {
             Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
         } else if data_types.len() > 1 {
             not_impl_err!("Arrays with different types are not supported: {data_types:?}")
         } else {
-            let data_type = values[0].get_datatype();
+            let data_type = values[0].data_type();
 
             Ok(lit(ScalarValue::new_list(Some(values), data_type)))
         }