You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2024/01/05 22:15:00 UTC

(arrow-datafusion) branch main updated: Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e4059a684 Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)
4e4059a684 is described below

commit 4e4059a68455fbc14f04902c76acbcd258b7f2ef
Author: Jay Zhan <ja...@gmail.com>
AuthorDate: Sat Jan 6 06:14:55 2024 +0800

    Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)
    
    * reuse function for string concat
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * remove casting in string concat
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * add test
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * operator to function rewrite
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * fix explain
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * add more test
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * add column cases
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * presever name
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * Update datafusion/optimizer/src/analyzer/rewrite_expr.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * rename
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    ---------
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/expr/src/type_coercion/binary.rs        |   2 -
 datafusion/optimizer/src/analyzer/mod.rs           |   6 +
 datafusion/optimizer/src/analyzer/rewrite_expr.rs  | 321 +++++++++++++++++++++
 datafusion/physical-expr/src/expressions/binary.rs |  11 +-
 datafusion/sql/src/expr/mod.rs                     |   2 +
 datafusion/sqllogictest/test_files/array.slt       |  39 +++
 datafusion/sqllogictest/test_files/explain.slt     |   1 +
 7 files changed, 371 insertions(+), 11 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 1b62c1bc05..6bacc18700 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
         (LargeUtf8, from_type) | (from_type, LargeUtf8) => {
             string_concat_internal_coercion(from_type, &LargeUtf8)
         }
-        // TODO: cast between array elements (#6558)
-        (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
         _ => None,
     })
 }
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index 14d5ddf473..9d47299a56 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod count_wildcard_rule;
 pub mod inline_table_scan;
+pub mod rewrite_expr;
 pub mod subquery;
 pub mod type_coercion;
 
@@ -37,6 +38,8 @@ use log::debug;
 use std::sync::Arc;
 use std::time::Instant;
 
+use self::rewrite_expr::OperatorToFunction;
+
 /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
 /// the plan valid prior to the rest of the DataFusion optimization process.
 ///
@@ -72,6 +75,9 @@ impl Analyzer {
     pub fn new() -> Self {
         let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
             Arc::new(InlineTableScan::new()),
+            // OperatorToFunction should be run before TypeCoercion, since it rewrites based on the argument types (List or Scalar),
+            // and TypeCoercion may cast the argument types from Scalar to List.
+            Arc::new(OperatorToFunction::new()),
             Arc::new(TypeCoercion::new()),
             Arc::new(CountWildcardRule::new()),
         ];
diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
new file mode 100644
index 0000000000..8f1c844ed0
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`)
+
+use std::sync::Arc;
+
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::utils::list_ndims;
+use datafusion_common::DFSchema;
+use datafusion_common::DFSchemaRef;
+use datafusion_common::Result;
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr_rewriter::rewrite_preserving_name;
+use datafusion_expr::utils::merge_schema;
+use datafusion_expr::BuiltinScalarFunction;
+use datafusion_expr::Operator;
+use datafusion_expr::ScalarFunctionDefinition;
+use datafusion_expr::{BinaryExpr, Expr, LogicalPlan};
+
+use super::AnalyzerRule;
+
+#[derive(Default)]
+pub struct OperatorToFunction {}
+
+impl OperatorToFunction {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl AnalyzerRule for OperatorToFunction {
+    fn name(&self) -> &str {
+        "operator_to_function"
+    }
+
+    fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
+        analyze_internal(&plan)
+    }
+}
+
+fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    // optimize child plans first
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|p| analyze_internal(p))
+        .collect::<Result<Vec<_>>>()?;
+
+    // get schema representing all available input fields. This is used for data type
+    // resolution only, so order does not matter here
+    let mut schema = merge_schema(new_inputs.iter().collect());
+
+    if let LogicalPlan::TableScan(ts) = plan {
+        let source_schema =
+            DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
+        schema.merge(&source_schema);
+    }
+
+    let mut expr_rewrite = OperatorToFunctionRewriter {
+        schema: Arc::new(schema),
+    };
+
+    let new_expr = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| {
+            // ensure names don't change:
+            // https://github.com/apache/arrow-datafusion/issues/3555
+            rewrite_preserving_name(expr, &mut expr_rewrite)
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    plan.with_new_exprs(new_expr, &new_inputs)
+}
+
+pub(crate) struct OperatorToFunctionRewriter {
+    pub(crate) schema: DFSchemaRef,
+}
+
+impl TreeNodeRewriter for OperatorToFunctionRewriter {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        match expr {
+            Expr::BinaryExpr(BinaryExpr {
+                ref left,
+                op,
+                ref right,
+            }) => {
+                if let Some(fun) = rewrite_array_concat_operator_to_func_for_column(
+                    left.as_ref(),
+                    op,
+                    right.as_ref(),
+                    self.schema.as_ref(),
+                )?
+                .or_else(|| {
+                    rewrite_array_concat_operator_to_func(
+                        left.as_ref(),
+                        op,
+                        right.as_ref(),
+                    )
+                }) {
+                    // Convert &Box<Expr> -> Expr
+                    let left = (**left).clone();
+                    let right = (**right).clone();
+                    return Ok(Expr::ScalarFunction(ScalarFunction {
+                        func_def: ScalarFunctionDefinition::BuiltIn(fun),
+                        args: vec![left, right],
+                    }));
+                }
+
+                Ok(expr)
+            }
+            _ => Ok(expr),
+        }
+    }
+}
+
+/// Summary of the logic below:
+///
+/// 1) array || array -> array concat
+///
+/// 2) array || scalar -> array append
+///
+/// 3) scalar || array -> array prepend
+///
+/// 4) (array concat, array append, array prepend) || array -> array concat
+///
+/// 5) (array concat, array append, array prepend) || scalar -> array append
+fn rewrite_array_concat_operator_to_func(
+    left: &Expr,
+    op: Operator,
+    right: &Expr,
+) -> Option<BuiltinScalarFunction> {
+    // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat
+
+    if op != Operator::StringConcat {
+        return None;
+    }
+
+    match (left, right) {
+        // Chain concat operator (a || b) || array,
+        // (array concat, array append, array prepend) || array -> array concat
+        (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+                args: _left_args,
+            }),
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _right_args,
+            }),
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+                args: _left_args,
+            }),
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _right_args,
+            }),
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+                args: _left_args,
+            }),
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _right_args,
+            }),
+        ) => Some(BuiltinScalarFunction::ArrayConcat),
+        // Chain concat operator (a || b) || scalar,
+        // (array concat, array append, array prepend) || scalar -> array append
+        (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+                args: _left_args,
+            }),
+            _scalar,
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+                args: _left_args,
+            }),
+            _scalar,
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+                args: _left_args,
+            }),
+            _scalar,
+        ) => Some(BuiltinScalarFunction::ArrayAppend),
+        // array || array -> array concat
+        (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _left_args,
+            }),
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _right_args,
+            }),
+        ) => Some(BuiltinScalarFunction::ArrayConcat),
+        // array || scalar -> array append
+        (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _left_args,
+            }),
+            _right_scalar,
+        ) => Some(BuiltinScalarFunction::ArrayAppend),
+        // scalar || array -> array prepend
+        (
+            _left_scalar,
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+                args: _right_args,
+            }),
+        ) => Some(BuiltinScalarFunction::ArrayPrepend),
+
+        _ => None,
+    }
+}
+
+/// Summary of the logic below:
+///
+/// 1) (array concat, array append, array prepend) || column -> (array append, array concat)
+///
+/// 2) column1 || column2 -> (array prepend, array append, array concat)
+fn rewrite_array_concat_operator_to_func_for_column(
+    left: &Expr,
+    op: Operator,
+    right: &Expr,
+    schema: &DFSchema,
+) -> Result<Option<BuiltinScalarFunction>> {
+    if op != Operator::StringConcat {
+        return Ok(None);
+    }
+
+    match (left, right) {
+        // Column cases:
+        // 1) array_prepend/append/concat || column
+        (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+                args: _left_args,
+            }),
+            Expr::Column(c),
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+                args: _left_args,
+            }),
+            Expr::Column(c),
+        )
+        | (
+            Expr::ScalarFunction(ScalarFunction {
+                func_def:
+                    ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+                args: _left_args,
+            }),
+            Expr::Column(c),
+        ) => {
+            let d = schema.field_from_column(c)?.data_type();
+            let ndim = list_ndims(d);
+            match ndim {
+                0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
+                _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
+            }
+        }
+        // 2) select column1 || column2
+        (Expr::Column(c1), Expr::Column(c2)) => {
+            let d1 = schema.field_from_column(c1)?.data_type();
+            let d2 = schema.field_from_column(c2)?.data_type();
+            let ndim1 = list_ndims(d1);
+            let ndim2 = list_ndims(d2);
+            match (ndim1, ndim2) {
+                (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)),
+                (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
+                _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
+            }
+        }
+        _ => Ok(None),
+    }
+}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c17081398c..8c4078dbce 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -20,9 +20,7 @@ mod kernels;
 use std::hash::{Hash, Hasher};
 use std::{any::Any, sync::Arc};
 
-use crate::array_expressions::{
-    array_append, array_concat, array_has_all, array_prepend,
-};
+use crate::array_expressions::array_has_all;
 use crate::expressions::datum::{apply, apply_cmp};
 use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
 use crate::physical_expr::down_cast_any_ref;
@@ -598,12 +596,7 @@ impl BinaryExpr {
             BitwiseXor => bitwise_xor_dyn(left, right),
             BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
             BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
-            StringConcat => match (left_data_type, right_data_type) {
-                (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
-                (DataType::List(_), _) => array_append(&[left, right]),
-                (_, DataType::List(_)) => array_prepend(&[left, right]),
-                _ => binary_string_array_op!(left, right, concat_elements),
-            },
+            StringConcat => binary_string_array_op!(left, right, concat_elements),
             AtArrow => array_has_all(&[left, right]),
             ArrowAt => array_has_all(&[right, left]),
         }
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 27351e10eb..9fded63af3 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 StackEntry::Operator(op) => {
                     let right = eval_stack.pop().unwrap();
                     let left = eval_stack.pop().unwrap();
+
                     let expr = Expr::BinaryExpr(BinaryExpr::new(
                         Box::new(left),
                         op,
                         Box::new(right),
                     ));
+
                     eval_stack.push(expr);
                 }
             }
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 083c4ff31b..d864091a85 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -4617,6 +4617,45 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a
 ----
 [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
 
+# array concatenate operator with scalars #4 (mixed)
+query ?
+select 0 || [1,2,3] || 4 || [5] || [6,7];
+----
+[0, 1, 2, 3, 4, 5, 6, 7]
+
+# array concatenate operator with nd-list #5 (mixed)
+query ?
+select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10];
+----
+[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]]
+
+# array concatenate operator non-valid cases
+## concat 2D with scalar is not valid
+query error
+select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11;
+
+## concat scalar with 2D is not valid
+query error
+select 0 || [[1,2,3]];
+
+# array concatenate operator with column
+
+statement ok
+CREATE TABLE array_concat_operator_table
+AS VALUES
+  (0, [1, 2, 2, 3], 4, [5, 6, 5]),
+  (-1, [4, 5, 6], 7, [8, 1, 1])
+;
+
+query ?
+select column1 || column2 || column3 || column4 from array_concat_operator_table;
+----
+[0, 1, 2, 2, 3, 4, 5, 6, 5]
+[-1, 4, 5, 6, 7, 8, 1, 1]
+
+statement ok
+drop table array_concat_operator_table;
+
 ## array containment operator
 
 # array containment operator with scalars #1 (at arrow)
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index 4583ef319b..2a39e31388 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -180,6 +180,7 @@ initial_logical_plan
 Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c
 --TableScan: simple_explain_test
 logical_plan after inline_table_scan SAME TEXT AS ABOVE
+logical_plan after operator_to_function SAME TEXT AS ABOVE
 logical_plan after type_coercion SAME TEXT AS ABOVE
 logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
 analyzed_logical_plan SAME TEXT AS ABOVE