You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/10/11 12:20:49 UTC

[arrow-datafusion] branch master updated: Add simplification rules for the `CONCAT` function (#3684)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new ac1631aa6 Add simplification rules for the `CONCAT` function (#3684)
ac1631aa6 is described below

commit ac1631aa6db411c78b38d3a70ccaeb6a89a83673
Author: Remzi Yang <59...@users.noreply.github.com>
AuthorDate: Tue Oct 11 20:20:43 2022 +0800

    Add simplification rules for the `CONCAT` function (#3684)
    
    * simpl concat
    
    Signed-off-by: remzi <13...@gmail.com>
    
    * update after type coercion
    
    Signed-off-by: remzi <13...@gmail.com>
    
    Signed-off-by: remzi <13...@gmail.com>
---
 datafusion/optimizer/src/simplify_expressions.rs | 68 ++++++++++++++++++++++++
 datafusion/optimizer/tests/integration-test.rs   | 13 +++++
 2 files changed, 81 insertions(+)

diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index c96f6eea7..95c640b2a 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -878,12 +878,56 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
                 out_expr.rewrite(self)?
             }
 
+            // concat
+            ScalarFunction {
+                fun: BuiltinScalarFunction::Concat,
+                args,
+            } => {
+                let mut new_args = Vec::with_capacity(args.len());
+                let mut contiguous_scalar = "".to_string();
+                for e in args {
+                    match e {
+                        // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
+                        // Concatenate it with `contiguous_scalar`.
+                        Expr::Literal(
+                            ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
+                        ) => {
+                            if let Some(s) = x {
+                                contiguous_scalar += &s;
+                            }
+                        }
+                        // If the arg is not a literal, first push the current `contiguous_scalar`
+                        // to `new_args` (if it is not empty) and reset it to the empty string.
+                        // Then push this arg to `new_args`.
+                        e => {
+                            if !contiguous_scalar.is_empty() {
+                                new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
+                                    contiguous_scalar.clone(),
+                                ))));
+                                contiguous_scalar = "".to_string();
+                            }
+                            new_args.push(e);
+                        }
+                    }
+                }
+                if !contiguous_scalar.is_empty() {
+                    new_args
+                        .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
+                }
+
+                ScalarFunction {
+                    fun: BuiltinScalarFunction::Concat,
+                    args: new_args,
+                }
+            }
+
             // concat_ws
             ScalarFunction {
                 fun: BuiltinScalarFunction::ConcatWithSeparator,
                 args,
             } => {
                 match &args[..] {
+                    // concat_ws(null, ..) --> null
                     [Expr::Literal(sp), ..] if sp.is_null() => {
                         Expr::Literal(ScalarValue::Utf8(None))
                     }
@@ -1352,6 +1396,30 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_simplify_concat() {
+        // Wraps `args` in a `concat(..)` scalar function expression.
+        let concat_of = |args: Vec<Expr>| Expr::ScalarFunction {
+            fun: BuiltinScalarFunction::Concat,
+            args,
+        };
+
+        // Null and empty literals are dropped; adjacent literals are merged.
+        let null_utf8 = Expr::Literal(ScalarValue::Utf8(None));
+        let input = concat_of(vec![
+            null_utf8.clone(),
+            col("c0"),
+            lit("hello "),
+            null_utf8.clone(),
+            lit("rust"),
+            col("c1"),
+            lit(""),
+            null_utf8,
+        ]);
+        let want = concat_of(vec![col("c0"), lit("hello rust"), col("c1")]);
+        assert_eq!(simplify(input), want)
+    }
+
     // ------------------------------
     // --- ConstEvaluator tests -----
     // ------------------------------
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 12a5b4447..af7bc635a 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -199,6 +199,19 @@ fn between_date64_plus_interval() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn concat_literals() -> Result<()> {
+    // Literals are rendered as Utf8, the null is dropped, and neighbours are merged.
+    let query = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
+        AS col
+        FROM test";
+    let want =
+        "Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
+        \n  TableScan: test projection=[col_int32, col_utf8]";
+    assert_eq!(want, format!("{:?}", test_sql(query)?));
+    Ok(())
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // parse the SQL
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...