You are viewing a plain-text version of this content; the canonical link to it (originally a hyperlink on the word "here") was not preserved in this copy.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/07 07:45:39 UTC

[arrow-datafusion] branch main updated: refactor: remove type_coercion in PhysicalExpr. (#6575)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 12b88ea9b3 refactor: remove type_coercion in PhysicalExpr. (#6575)
12b88ea9b3 is described below

commit 12b88ea9b38351ba6a1596defea3ac5ec57b4378
Author: jakevin <ja...@gmail.com>
AuthorDate: Wed Jun 7 15:45:30 2023 +0800

    refactor: remove type_coercion in PhysicalExpr. (#6575)
---
 datafusion/core/src/physical_plan/mod.rs         |   4 +-
 datafusion/core/src/physical_plan/windows/mod.rs |  35 ++--
 datafusion/physical-expr/src/functions.rs        |  32 +++-
 datafusion/physical-expr/src/lib.rs              |   1 -
 datafusion/physical-expr/src/sort_expr.rs        |   2 +-
 datafusion/physical-expr/src/type_coercion.rs    | 201 -----------------------
 6 files changed, 44 insertions(+), 231 deletions(-)

diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index ce5dc041b9..2346452517 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -703,9 +703,7 @@ use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::repartition::RepartitionExec;
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion_execution::TaskContext;
-pub use datafusion_physical_expr::{
-    expressions, functions, hash_utils, type_coercion, udf,
-};
+pub use datafusion_physical_expr::{expressions, functions, hash_utils, udf};
 
 #[cfg(test)]
 mod tests {
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index a17b8ba87e..b4b2d6bc64 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -24,14 +24,13 @@ use crate::physical_plan::{
         cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile,
         PhysicalSortExpr, RowNumber,
     },
-    type_coercion::coerce,
     udaf, ExecutionPlan, PhysicalExpr,
 };
 use crate::scalar::ScalarValue;
 use arrow::datatypes::Schema;
 use arrow_schema::{SchemaRef, SortOptions};
 use datafusion_expr::{
-    window_function::{signature_for_built_in, BuiltInWindowFunction, WindowFunction},
+    window_function::{BuiltInWindowFunction, WindowFunction},
     WindowFrame,
 };
 use datafusion_physical_expr::window::{
@@ -133,8 +132,7 @@ fn create_built_in_window_expr(
         BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name)),
         BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name)),
         BuiltInWindowFunction::Ntile => {
-            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
-            let n: i64 = get_scalar_value_from_args(&coerced_args, 0)?
+            let n: i64 = get_scalar_value_from_args(args, 0)?
                 .ok_or_else(|| {
                     DataFusionError::Execution(
                         "NTILE requires at least 1 argument".to_string(),
@@ -145,33 +143,26 @@ fn create_built_in_window_expr(
             Arc::new(Ntile::new(name, n))
         }
         BuiltInWindowFunction::Lag => {
-            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
-            let arg = coerced_args[0].clone();
+            let arg = args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            let shift_offset = get_scalar_value_from_args(&coerced_args, 1)?
+            let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
-            let default_value = get_scalar_value_from_args(&coerced_args, 2)?;
+            let default_value = get_scalar_value_from_args(args, 2)?;
             Arc::new(lag(name, data_type, arg, shift_offset, default_value))
         }
         BuiltInWindowFunction::Lead => {
-            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
-            let arg = coerced_args[0].clone();
+            let arg = args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
-            let shift_offset = get_scalar_value_from_args(&coerced_args, 1)?
+            let shift_offset = get_scalar_value_from_args(args, 1)?
                 .map(|v| v.try_into())
                 .and_then(|v| v.ok());
-            let default_value = get_scalar_value_from_args(&coerced_args, 2)?;
+            let default_value = get_scalar_value_from_args(args, 2)?;
             Arc::new(lead(name, data_type, arg, shift_offset, default_value))
         }
         BuiltInWindowFunction::NthValue => {
-            let coerced_args = coerce(args, input_schema, &signature_for_built_in(fun))?;
-            let arg = coerced_args[0].clone();
-            let n = coerced_args[1]
-                .as_any()
-                .downcast_ref::<Literal>()
-                .unwrap()
-                .value();
+            let arg = args[0].clone();
+            let n = args[1].as_any().downcast_ref::<Literal>().unwrap().value();
             let n: i64 = n
                 .clone()
                 .try_into()
@@ -181,14 +172,12 @@ fn create_built_in_window_expr(
             Arc::new(NthValue::nth(name, arg, data_type, n)?)
         }
         BuiltInWindowFunction::FirstValue => {
-            let arg =
-                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let arg = args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
             Arc::new(NthValue::first(name, arg, data_type))
         }
         BuiltInWindowFunction::LastValue => {
-            let arg =
-                coerce(args, input_schema, &signature_for_built_in(fun))?[0].clone();
+            let arg = args[0].clone();
             let data_type = args[0].data_type(input_schema)?;
             Arc::new(NthValue::last(name, arg, data_type))
         }
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 15728854ff..e8b262c761 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -809,9 +809,9 @@ pub fn create_physical_fun(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::expressions::try_cast;
     use crate::expressions::{col, lit};
     use crate::from_slice::FromSlice;
-    use crate::type_coercion::coerce;
     use arrow::{
         array::{
             Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
@@ -822,6 +822,8 @@ mod tests {
     };
     use datafusion_common::cast::as_uint64_array;
     use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::type_coercion::functions::data_types;
+    use datafusion_expr::Signature;
 
     /// $FUNC function to test
     /// $ARGS arguments (vec) to pass to function
@@ -2885,7 +2887,33 @@ mod tests {
         Ok(())
     }
 
-    // Helper function
+    // Helper function just for testing.
+    // Returns `expressions` coerced to types compatible with
+    // `signature`, if possible.
+    pub fn coerce(
+        expressions: &[Arc<dyn PhysicalExpr>],
+        schema: &Schema,
+        signature: &Signature,
+    ) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
+        if expressions.is_empty() {
+            return Ok(vec![]);
+        }
+
+        let current_types = expressions
+            .iter()
+            .map(|e| e.data_type(schema))
+            .collect::<Result<Vec<_>>>()?;
+
+        let new_types = data_types(&current_types, signature)?;
+
+        expressions
+            .iter()
+            .enumerate()
+            .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone()))
+            .collect::<Result<Vec<_>>>()
+    }
+
+    // Helper function just for testing.
     // The type coercion will be done in the logical phase, should do the type coercion for the test
     fn create_physical_expr_with_type_coercion(
         fun: &BuiltinScalarFunction,
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 710e9342b1..21a88b5d89 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -37,7 +37,6 @@ mod sort_expr;
 pub mod string_expressions;
 pub mod struct_expressions;
 pub mod tree_node;
-pub mod type_coercion;
 pub mod udf;
 #[cfg(feature = "unicode_expressions")]
 pub mod unicode_expressions;
diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs
index dc93b67fa6..df519551d8 100644
--- a/datafusion/physical-expr/src/sort_expr.rs
+++ b/datafusion/physical-expr/src/sort_expr.rs
@@ -80,7 +80,7 @@ impl PhysicalSortExpr {
 
 /// Represents sort requirement associated with a plan
 ///
-/// If the requirement incudes [`SortOptions`] then both the
+/// If the requirement includes [`SortOptions`] then both the
 /// expression *and* the sort options must match.
 ///
 /// If the requirement does not include [`SortOptions`]) then only the
diff --git a/datafusion/physical-expr/src/type_coercion.rs b/datafusion/physical-expr/src/type_coercion.rs
deleted file mode 100644
index 399dcc0899..0000000000
--- a/datafusion/physical-expr/src/type_coercion.rs
+++ /dev/null
@@ -1,201 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Type coercion rules for functions with multiple valid signatures
-//!
-//! Coercion is performed automatically by DataFusion when the types
-//! of arguments passed to a function do not exactly match the types
-//! required by that function. In this case, DataFusion will attempt to
-//! *coerce* the arguments to types accepted by the function by
-//! inserting CAST operations.
-//!
-//! CAST operations added by coercion are lossless and never discard
-//! information. For example coercion from i32 -> i64 might be
-//! performed because all valid i32 values can be represented using an
-//! i64. However, i64 -> i32 is never performed as there are i64
-//! values which can not be represented by i32 values.
-
-use super::PhysicalExpr;
-use crate::expressions::try_cast;
-use arrow::datatypes::Schema;
-use datafusion_common::Result;
-use datafusion_expr::{type_coercion::functions::data_types, Signature};
-use std::{sync::Arc, vec};
-
-/// Returns `expressions` coerced to types compatible with
-/// `signature`, if possible.
-///
-/// See the module level documentation for more detail on coercion.
-pub fn coerce(
-    expressions: &[Arc<dyn PhysicalExpr>],
-    schema: &Schema,
-    signature: &Signature,
-) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
-    if expressions.is_empty() {
-        return Ok(vec![]);
-    }
-
-    let current_types = expressions
-        .iter()
-        .map(|e| e.data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    let new_types = data_types(&current_types, signature)?;
-
-    expressions
-        .iter()
-        .enumerate()
-        .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone()))
-        .collect::<Result<Vec<_>>>()
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::expressions::col;
-    use arrow::datatypes::{DataType, Field, Schema};
-    use arrow_schema::Fields;
-    use datafusion_common::DataFusionError;
-    use datafusion_expr::Volatility;
-
-    #[test]
-    fn test_coerce() -> Result<()> {
-        // create a schema
-        let schema = |t: Vec<DataType>| {
-            Schema::new(
-                t.iter()
-                    .enumerate()
-                    .map(|(i, t)| Field::new(format!("c{i}"), t.clone(), true))
-                    .collect::<Fields>(),
-            )
-        };
-
-        // create a vector of expressions
-        let expressions = |t: Vec<DataType>, schema| -> Result<Vec<_>> {
-            t.iter()
-                .enumerate()
-                .map(|(i, t)| {
-                    try_cast(col(&format!("c{i}"), &schema)?, &schema, t.clone())
-                })
-                .collect::<Result<Vec<_>>>()
-        };
-
-        // create a case: input + expected result
-        let case =
-            |observed: Vec<DataType>, valid, expected: Vec<DataType>| -> Result<_> {
-                let schema = schema(observed.clone());
-                let expr = expressions(observed, schema.clone())?;
-                let expected = expressions(expected, schema.clone())?;
-                Ok((expr.clone(), schema, valid, expected))
-            };
-
-        let cases = vec![
-            // u16 -> u32
-            case(
-                vec![DataType::UInt16],
-                Signature::uniform(1, vec![DataType::UInt32], Volatility::Immutable),
-                vec![DataType::UInt32],
-            )?,
-            // same type
-            case(
-                vec![DataType::UInt32, DataType::UInt32],
-                Signature::uniform(2, vec![DataType::UInt32], Volatility::Immutable),
-                vec![DataType::UInt32, DataType::UInt32],
-            )?,
-            case(
-                vec![DataType::UInt32],
-                Signature::uniform(
-                    1,
-                    vec![DataType::Float32, DataType::Float64],
-                    Volatility::Immutable,
-                ),
-                vec![DataType::Float32],
-            )?,
-            // u32 -> f32
-            case(
-                vec![DataType::UInt32, DataType::UInt32],
-                Signature::variadic(vec![DataType::Float32], Volatility::Immutable),
-                vec![DataType::Float32, DataType::Float32],
-            )?,
-            // u32 -> f32
-            case(
-                vec![DataType::Float32, DataType::UInt32],
-                Signature::variadic_equal(Volatility::Immutable),
-                vec![DataType::Float32, DataType::Float32],
-            )?,
-            // common type is u64
-            case(
-                vec![DataType::UInt32, DataType::UInt64],
-                Signature::variadic(
-                    vec![DataType::UInt32, DataType::UInt64],
-                    Volatility::Immutable,
-                ),
-                vec![DataType::UInt64, DataType::UInt64],
-            )?,
-            // f32 -> f32
-            case(
-                vec![DataType::Float32],
-                Signature::any(1, Volatility::Immutable),
-                vec![DataType::Float32],
-            )?,
-        ];
-
-        for case in cases {
-            let observed = format!("{:?}", coerce(&case.0, &case.1, &case.2)?);
-            let expected = format!("{:?}", case.3);
-            assert_eq!(observed, expected);
-        }
-
-        // now cases that are expected to fail
-        let cases = vec![
-            // we do not know how to cast bool to UInt16 => fail
-            case(
-                vec![DataType::Boolean],
-                Signature::uniform(1, vec![DataType::UInt16], Volatility::Immutable),
-                vec![],
-            )?,
-            // u32 and bool are not uniform
-            case(
-                vec![DataType::UInt32, DataType::Boolean],
-                Signature::variadic_equal(Volatility::Immutable),
-                vec![],
-            )?,
-            // bool is not castable to u32
-            case(
-                vec![DataType::Boolean, DataType::Boolean],
-                Signature::variadic(vec![DataType::UInt32], Volatility::Immutable),
-                vec![],
-            )?,
-            // expected two arguments
-            case(
-                vec![DataType::UInt32],
-                Signature::any(2, Volatility::Immutable),
-                vec![],
-            )?,
-        ];
-
-        for case in cases {
-            if coerce(&case.0, &case.1, &case.2).is_ok() {
-                return Err(DataFusionError::Plan(format!(
-                    "Error was expected in {case:?}"
-                )));
-            }
-        }
-
-        Ok(())
-    }
-}