You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/10/11 12:20:49 UTC
[arrow-datafusion] branch master updated: Add simplification rules for the `CONCAT` function (#3684)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new ac1631aa6 Add simplification rules for the `CONCAT` function (#3684)
ac1631aa6 is described below
commit ac1631aa6db411c78b38d3a70ccaeb6a89a83673
Author: Remzi Yang <59...@users.noreply.github.com>
AuthorDate: Tue Oct 11 20:20:43 2022 +0800
Add simplification rules for the `CONCAT` function (#3684)
* simpl concat
Signed-off-by: remzi <13...@gmail.com>
* update after type coercion
Signed-off-by: remzi <13...@gmail.com>
Signed-off-by: remzi <13...@gmail.com>
---
datafusion/optimizer/src/simplify_expressions.rs | 68 ++++++++++++++++++++++++
datafusion/optimizer/tests/integration-test.rs | 13 +++++
2 files changed, 81 insertions(+)
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index c96f6eea7..95c640b2a 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -878,12 +878,56 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
out_expr.rewrite(self)?
}
+ // concat
+ ScalarFunction {
+ fun: BuiltinScalarFunction::Concat,
+ args,
+ } => {
+ let mut new_args = Vec::with_capacity(args.len());
+ let mut contiguous_scalar = "".to_string();
+ for e in args {
+ match e {
+ // All literals have already been converted to Utf8 or LargeUtf8 in type_coercion.
+ // Append each non-null literal to `contiguous_scalar`; null literals are dropped.
+ Expr::Literal(
+ ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
+ ) => {
+ if let Some(s) = x {
+ contiguous_scalar += &s;
+ }
+ }
+ // If the arg is not a string literal, first push the current `contiguous_scalar`
+ // to `new_args` (if it is not empty) and reset it to the empty string.
+ // Then push this arg to `new_args`.
+ e => {
+ if !contiguous_scalar.is_empty() {
+ new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
+ contiguous_scalar.clone(),
+ ))));
+ contiguous_scalar = "".to_string();
+ }
+ new_args.push(e);
+ }
+ }
+ }
+ if !contiguous_scalar.is_empty() {
+ new_args
+ .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
+ }
+
+ ScalarFunction {
+ fun: BuiltinScalarFunction::Concat,
+ args: new_args,
+ }
+ }
+
// concat_ws
ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
} => {
match &args[..] {
+ // concat_ws(null, ..) --> null
[Expr::Literal(sp), ..] if sp.is_null() => {
Expr::Literal(ScalarValue::Utf8(None))
}
@@ -1352,6 +1396,30 @@ mod tests {
}
}
+ #[test]
+ fn test_simplify_concat() {
+ fn build_concat_expr(args: &[Expr]) -> Expr {
+ Expr::ScalarFunction {
+ fun: BuiltinScalarFunction::Concat,
+ args: args.to_vec(),
+ }
+ }
+
+ let null = Expr::Literal(ScalarValue::Utf8(None));
+ let expr = build_concat_expr(&[
+ null.clone(),
+ col("c0"),
+ lit("hello "),
+ null.clone(),
+ lit("rust"),
+ col("c1"),
+ lit(""),
+ null,
+ ]);
+ let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]);
+ assert_eq!(simplify(expr), expected)
+ }
+
// ------------------------------
// --- ConstEvaluator tests -----
// ------------------------------
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 12a5b4447..af7bc635a 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -199,6 +199,19 @@ fn between_date64_plus_interval() -> Result<()> {
Ok(())
}
+#[test]
+fn concat_literals() -> Result<()> {
+ let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
+ AS col
+ FROM test";
+ let plan = test_sql(sql)?;
+ let expected =
+ "Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
+ \n TableScan: test projection=[col_int32, col_utf8]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...