You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2024/01/05 22:15:00 UTC
(arrow-datafusion) branch main updated: Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4e4059a684 Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)
4e4059a684 is described below
commit 4e4059a68455fbc14f04902c76acbcd258b7f2ef
Author: Jay Zhan <ja...@gmail.com>
AuthorDate: Sat Jan 6 06:14:55 2024 +0800
Convert Binary Operator `StringConcat` to Function for `array_concat`, `array_append` and `array_prepend` (#8636)
* reuse function for string concat
Signed-off-by: jayzhan211 <ja...@gmail.com>
* remove casting in string concat
Signed-off-by: jayzhan211 <ja...@gmail.com>
* add test
Signed-off-by: jayzhan211 <ja...@gmail.com>
* operator to function rewrite
Signed-off-by: jayzhan211 <ja...@gmail.com>
* fix explain
Signed-off-by: jayzhan211 <ja...@gmail.com>
* add more test
Signed-off-by: jayzhan211 <ja...@gmail.com>
* add column cases
Signed-off-by: jayzhan211 <ja...@gmail.com>
* cleanup
Signed-off-by: jayzhan211 <ja...@gmail.com>
* preserve name
Signed-off-by: jayzhan211 <ja...@gmail.com>
* Update datafusion/optimizer/src/analyzer/rewrite_expr.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* rename
Signed-off-by: jayzhan211 <ja...@gmail.com>
---------
Signed-off-by: jayzhan211 <ja...@gmail.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/expr/src/type_coercion/binary.rs | 2 -
datafusion/optimizer/src/analyzer/mod.rs | 6 +
datafusion/optimizer/src/analyzer/rewrite_expr.rs | 321 +++++++++++++++++++++
datafusion/physical-expr/src/expressions/binary.rs | 11 +-
datafusion/sql/src/expr/mod.rs | 2 +
datafusion/sqllogictest/test_files/array.slt | 39 +++
datafusion/sqllogictest/test_files/explain.slt | 1 +
7 files changed, 371 insertions(+), 11 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 1b62c1bc05..6bacc18700 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -667,8 +667,6 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
- // TODO: cast between array elements (#6558)
- (List(_), from_type) | (from_type, List(_)) => Some(from_type.to_owned()),
_ => None,
})
}
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index 14d5ddf473..9d47299a56 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -17,6 +17,7 @@
pub mod count_wildcard_rule;
pub mod inline_table_scan;
+pub mod rewrite_expr;
pub mod subquery;
pub mod type_coercion;
@@ -37,6 +38,8 @@ use log::debug;
use std::sync::Arc;
use std::time::Instant;
+use self::rewrite_expr::OperatorToFunction;
+
/// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make
/// the plan valid prior to the rest of the DataFusion optimization process.
///
@@ -72,6 +75,9 @@ impl Analyzer {
pub fn new() -> Self {
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
Arc::new(InlineTableScan::new()),
+ // OperatorToFunction should be run before TypeCoercion, since it rewrites expressions based on the argument types (List or Scalar),
+ // and TypeCoercion may cast the argument types from Scalar to List.
+ Arc::new(OperatorToFunction::new()),
Arc::new(TypeCoercion::new()),
Arc::new(CountWildcardRule::new()),
];
diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
new file mode 100644
index 0000000000..8f1c844ed0
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Analyzer rule to replace operators with function calls (e.g. `||` to `array_concat`)
+
+use std::sync::Arc;
+
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::TreeNodeRewriter;
+use datafusion_common::utils::list_ndims;
+use datafusion_common::DFSchema;
+use datafusion_common::DFSchemaRef;
+use datafusion_common::Result;
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr_rewriter::rewrite_preserving_name;
+use datafusion_expr::utils::merge_schema;
+use datafusion_expr::BuiltinScalarFunction;
+use datafusion_expr::Operator;
+use datafusion_expr::ScalarFunctionDefinition;
+use datafusion_expr::{BinaryExpr, Expr, LogicalPlan};
+
+use super::AnalyzerRule;
+
+#[derive(Default)]
+pub struct OperatorToFunction {}
+
+impl OperatorToFunction {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl AnalyzerRule for OperatorToFunction {
+ fn name(&self) -> &str {
+ "operator_to_function"
+ }
+
+ fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
+ analyze_internal(&plan)
+ }
+}
+
+fn analyze_internal(plan: &LogicalPlan) -> Result<LogicalPlan> {
+ // optimize child plans first
+ let new_inputs = plan
+ .inputs()
+ .iter()
+ .map(|p| analyze_internal(p))
+ .collect::<Result<Vec<_>>>()?;
+
+ // get schema representing all available input fields. This is used for data type
+ // resolution only, so order does not matter here
+ let mut schema = merge_schema(new_inputs.iter().collect());
+
+ if let LogicalPlan::TableScan(ts) = plan {
+ let source_schema =
+ DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
+ schema.merge(&source_schema);
+ }
+
+ let mut expr_rewrite = OperatorToFunctionRewriter {
+ schema: Arc::new(schema),
+ };
+
+ let new_expr = plan
+ .expressions()
+ .into_iter()
+ .map(|expr| {
+ // ensure names don't change:
+ // https://github.com/apache/arrow-datafusion/issues/3555
+ rewrite_preserving_name(expr, &mut expr_rewrite)
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ plan.with_new_exprs(new_expr, &new_inputs)
+}
+
+pub(crate) struct OperatorToFunctionRewriter {
+ pub(crate) schema: DFSchemaRef,
+}
+
+impl TreeNodeRewriter for OperatorToFunctionRewriter {
+ type N = Expr;
+
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr {
+ ref left,
+ op,
+ ref right,
+ }) => {
+ if let Some(fun) = rewrite_array_concat_operator_to_func_for_column(
+ left.as_ref(),
+ op,
+ right.as_ref(),
+ self.schema.as_ref(),
+ )?
+ .or_else(|| {
+ rewrite_array_concat_operator_to_func(
+ left.as_ref(),
+ op,
+ right.as_ref(),
+ )
+ }) {
+ // Convert &Box<Expr> -> Expr
+ let left = (**left).clone();
+ let right = (**right).clone();
+ return Ok(Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn(fun),
+ args: vec![left, right],
+ }));
+ }
+
+ Ok(expr)
+ }
+ _ => Ok(expr),
+ }
+ }
+}
+
+/// Summary of the logic below:
+///
+/// 1) array || array -> array concat
+///
+/// 2) array || scalar -> array append
+///
+/// 3) scalar || array -> array prepend
+///
+/// 4) (array concat, array append, array prepend) || array -> array concat
+///
+/// 5) (array concat, array append, array prepend) || scalar -> array append
+fn rewrite_array_concat_operator_to_func(
+ left: &Expr,
+ op: Operator,
+ right: &Expr,
+) -> Option<BuiltinScalarFunction> {
+ // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat
+
+ if op != Operator::StringConcat {
+ return None;
+ }
+
+ match (left, right) {
+ // Chain concat operator (a || b) || array,
+ // (array concat, array append, array prepend) || array -> array concat
+ (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+ args: _left_args,
+ }),
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _right_args,
+ }),
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+ args: _left_args,
+ }),
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _right_args,
+ }),
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+ args: _left_args,
+ }),
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _right_args,
+ }),
+ ) => Some(BuiltinScalarFunction::ArrayConcat),
+ // Chain concat operator (a || b) || scalar,
+ // (array concat, array append, array prepend) || scalar -> array append
+ (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+ args: _left_args,
+ }),
+ _scalar,
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+ args: _left_args,
+ }),
+ _scalar,
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+ args: _left_args,
+ }),
+ _scalar,
+ ) => Some(BuiltinScalarFunction::ArrayAppend),
+ // array || array -> array concat
+ (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _left_args,
+ }),
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _right_args,
+ }),
+ ) => Some(BuiltinScalarFunction::ArrayConcat),
+ // array || scalar -> array append
+ (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _left_args,
+ }),
+ _right_scalar,
+ ) => Some(BuiltinScalarFunction::ArrayAppend),
+ // scalar || array -> array prepend
+ (
+ _left_scalar,
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::MakeArray),
+ args: _right_args,
+ }),
+ ) => Some(BuiltinScalarFunction::ArrayPrepend),
+
+ _ => None,
+ }
+}
+
+/// Summary of the logic below:
+///
+/// 1) (array concat, array append, array prepend) || column -> (array append, array concat)
+///
+/// 2) column1 || column2 -> (array prepend, array append, array concat)
+fn rewrite_array_concat_operator_to_func_for_column(
+ left: &Expr,
+ op: Operator,
+ right: &Expr,
+ schema: &DFSchema,
+) -> Result<Option<BuiltinScalarFunction>> {
+ if op != Operator::StringConcat {
+ return Ok(None);
+ }
+
+ match (left, right) {
+ // Column cases:
+ // 1) array_prepend/append/concat || column
+ (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayPrepend),
+ args: _left_args,
+ }),
+ Expr::Column(c),
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayAppend),
+ args: _left_args,
+ }),
+ Expr::Column(c),
+ )
+ | (
+ Expr::ScalarFunction(ScalarFunction {
+ func_def:
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::ArrayConcat),
+ args: _left_args,
+ }),
+ Expr::Column(c),
+ ) => {
+ let d = schema.field_from_column(c)?.data_type();
+ let ndim = list_ndims(d);
+ match ndim {
+ 0 => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
+ _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
+ }
+ }
+ // 2) select column1 || column2
+ (Expr::Column(c1), Expr::Column(c2)) => {
+ let d1 = schema.field_from_column(c1)?.data_type();
+ let d2 = schema.field_from_column(c2)?.data_type();
+ let ndim1 = list_ndims(d1);
+ let ndim2 = list_ndims(d2);
+ match (ndim1, ndim2) {
+ (0, _) => Ok(Some(BuiltinScalarFunction::ArrayPrepend)),
+ (_, 0) => Ok(Some(BuiltinScalarFunction::ArrayAppend)),
+ _ => Ok(Some(BuiltinScalarFunction::ArrayConcat)),
+ }
+ }
+ _ => Ok(None),
+ }
+}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c17081398c..8c4078dbce 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -20,9 +20,7 @@ mod kernels;
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};
-use crate::array_expressions::{
- array_append, array_concat, array_has_all, array_prepend,
-};
+use crate::array_expressions::array_has_all;
use crate::expressions::datum::{apply, apply_cmp};
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::physical_expr::down_cast_any_ref;
@@ -598,12 +596,7 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
- StringConcat => match (left_data_type, right_data_type) {
- (DataType::List(_), DataType::List(_)) => array_concat(&[left, right]),
- (DataType::List(_), _) => array_append(&[left, right]),
- (_, DataType::List(_)) => array_prepend(&[left, right]),
- _ => binary_string_array_op!(left, right, concat_elements),
- },
+ StringConcat => binary_string_array_op!(left, right, concat_elements),
AtArrow => array_has_all(&[left, right]),
ArrowAt => array_has_all(&[right, left]),
}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 27351e10eb..9fded63af3 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -98,11 +98,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
StackEntry::Operator(op) => {
let right = eval_stack.pop().unwrap();
let left = eval_stack.pop().unwrap();
+
let expr = Expr::BinaryExpr(BinaryExpr::new(
Box::new(left),
op,
Box::new(right),
));
+
eval_stack.push(expr);
}
}
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 083c4ff31b..d864091a85 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -4617,6 +4617,45 @@ select 1 || make_array(2, 3, 4), 1.0 || make_array(2.0, 3.0, 4.0), 'h' || make_a
----
[1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o]
+# array concatenate operator with scalars #4 (mixed)
+query ?
+select 0 || [1,2,3] || 4 || [5] || [6,7];
+----
+[0, 1, 2, 3, 4, 5, 6, 7]
+
+# array concatenate operator with nd-list #5 (mixed)
+query ?
+select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10];
+----
+[[0, 1, 2, 3], [4, 5], [6, 7, 8], [9, 10]]
+
+# array concatenate operator non-valid cases
+## concat 2D with scalar is not valid
+query error
+select 0 || [1,2,3] || [[4,5]] || [[6,7,8]] || [9,10] || 11;
+
+## concat scalar with 2D is not valid
+query error
+select 0 || [[1,2,3]];
+
+# array concatenate operator with column
+
+statement ok
+CREATE TABLE array_concat_operator_table
+AS VALUES
+ (0, [1, 2, 2, 3], 4, [5, 6, 5]),
+ (-1, [4, 5, 6], 7, [8, 1, 1])
+;
+
+query ?
+select column1 || column2 || column3 || column4 from array_concat_operator_table;
+----
+[0, 1, 2, 2, 3, 4, 5, 6, 5]
+[-1, 4, 5, 6, 7, 8, 1, 1]
+
+statement ok
+drop table array_concat_operator_table;
+
## array containment operator
# array containment operator with scalars #1 (at arrow)
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index 4583ef319b..2a39e31388 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -180,6 +180,7 @@ initial_logical_plan
Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c
--TableScan: simple_explain_test
logical_plan after inline_table_scan SAME TEXT AS ABOVE
+logical_plan after operator_to_function SAME TEXT AS ABOVE
logical_plan after type_coercion SAME TEXT AS ABOVE
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
analyzed_logical_plan SAME TEXT AS ABOVE