You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/10/18 14:05:31 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #3869: Optimize the `concat_ws` function

alamb commented on code in PR #3869:
URL: https://github.com/apache/arrow-datafusion/pull/3869#discussion_r998273736


##########
datafusion/expr/src/literal.rs:
##########
@@ -53,6 +53,12 @@ impl Literal for String {
     }
 }
 
+impl Literal for &String {

Review Comment:
   "TIL" `lit()` 👍 



##########
datafusion/optimizer/src/simplify_expressions.rs:
##########
@@ -880,62 +999,19 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
             Expr::ScalarFunction {
                 fun: BuiltinScalarFunction::Concat,
                 args,
-            } => {
-                let mut new_args = Vec::with_capacity(args.len());
-                let mut contiguous_scalar = "".to_string();
-                for e in args {
-                    match e {
-                        // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
-                        // Concatenate it with `contiguous_scalar`.
-                        Expr::Literal(
-                            ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
-                        ) => {
-                            if let Some(s) = x {
-                                contiguous_scalar += &s;
-                            }
-                        }
-                        // If the arg is not a literal, we should first push the current `contiguous_scalar`
-                        // to the `new_args` (if it is not empty) and reset it to empty string.
-                        // Then pushing this arg to the `new_args`.
-                        e => {
-                            if !contiguous_scalar.is_empty() {
-                                new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
-                                    contiguous_scalar.clone(),
-                                ))));
-                                contiguous_scalar = "".to_string();
-                            }
-                            new_args.push(e);
-                        }
-                    }
-                }
-                if !contiguous_scalar.is_empty() {
-                    new_args
-                        .push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
-                }
-
-                Expr::ScalarFunction {
-                    fun: BuiltinScalarFunction::Concat,
-                    args: new_args,
-                }
-            }
+            } => simpl_concat(args)?,

Review Comment:
   ❤️ 



##########
datafusion/optimizer/src/simplify_expressions.rs:
##########
@@ -256,6 +256,125 @@ fn negate_clause(expr: Expr) -> Expr {
     }
 }
 
+/// Simplify the `concat` function by
+/// 1. filtering out all `null` literals
+/// 2. concatenating contiguous literal arguments
+///
+/// For example:
+/// `concat(col(a), 'hello ', 'world', col(b), null)`
+/// will be optimized to
+/// `concat(col(a), 'hello world', col(b))`
+fn simpl_concat(args: Vec<Expr>) -> Result<Expr> {
+    let mut new_args = Vec::with_capacity(args.len());
+    let mut contiguous_scalar = "".to_string();
+    for arg in args {
+        match arg {
+            // filter out `null` args
+            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
+            // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
+            // Concatenate it with the `contiguous_scalar`.
+            Expr::Literal(
+                ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)),
+            ) => contiguous_scalar += &v,
+            Expr::Literal(x) => {
+                return Err(DataFusionError::Internal(format!(
+                "The scalar {} should be casted to string type during the type coercion.",
+                x
+            )))
+            }
+            // If the arg is not a literal, first push the current `contiguous_scalar`
+            // to `new_args` (if it is not empty) and reset it to an empty string.
+            // Then push this arg to `new_args`.
+            arg => {
+                if !contiguous_scalar.is_empty() {
+                    new_args.push(lit(contiguous_scalar));
+                    contiguous_scalar = "".to_string();
+                }
+                new_args.push(arg);
+            }
+        }
+    }
+    if !contiguous_scalar.is_empty() {
+        new_args.push(lit(contiguous_scalar));
+    }
+
+    Ok(Expr::ScalarFunction {
+        fun: BuiltinScalarFunction::Concat,
+        args: new_args,
+    })
+}
+
+/// Simplify the `concat_ws` function by
+/// 1. folding to `null` if the delimiter is null
+/// 2. filtering out `null` arguments
+/// 3. using `concat` to replace `concat_ws` if the delimiter is an empty string
+/// 4. concatenating contiguous literals if the delimiter is a literal.
+fn simpl_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<Expr> {
+    match delimiter {
+        Expr::Literal(
+            ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter),
+        ) => {
+            match delimiter {
+                // when the delimiter is an empty string,
+                // we can use `concat` to replace `concat_ws`
+                Some(delimiter) if delimiter.is_empty() => simpl_concat(args.to_vec()),
+                Some(delimiter) => {
+                    let mut new_args = Vec::with_capacity(args.len());
+                    new_args.push(lit(delimiter));
+                    let mut contiguous_scalar = None;
+                    for arg in args {
+                        match arg {
+                            // filter out null args
+                            Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {}
+                            Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => {
+                                match contiguous_scalar {
+                                    None => contiguous_scalar = Some(v.to_string()),
+                                    Some(mut pre) => {
+                                        pre += delimiter;
+                                        pre += v;
+                                        contiguous_scalar = Some(pre)
+                                    }
+                                }
+                            }
+                            Expr::Literal(s) => return Err(DataFusionError::Internal(format!("The scalar {} should be casted to string type during the type coercion.", s))),
+                            // If the arg is not a literal, first push the current `contiguous_scalar`
+                            // to `new_args` and reset it to None.
+                            // Then push this arg to `new_args`.
+                            arg => {
+                                if let Some(val) = contiguous_scalar {
+                                    new_args.push(lit(val));
+                                }
+                                new_args.push(arg.clone());
+                                contiguous_scalar = None;
+                            }
+                        }
+                    }
+                    if let Some(val) = contiguous_scalar {
+                        new_args.push(lit(val));
+                    }

Review Comment:
   This pattern of building up the contiguous scalar is very similar to the one in `simpl_concat` -- I wonder if it could be extracted into a shared helper function -- perhaps as a follow-on PR



##########
datafusion/optimizer/src/simplify_expressions.rs:
##########
@@ -1379,56 +1456,81 @@ mod tests {
     }
 
     #[test]
-    fn test_simplify_concat_ws_null_separator() {
-        fn build_concat_ws_expr(args: &[Expr]) -> Expr {
-            Expr::ScalarFunction {
-                fun: BuiltinScalarFunction::ConcatWithSeparator,
-                args: args.to_vec(),
-            }
+    fn test_simplify_concat_ws() {
+        let null = Expr::Literal(ScalarValue::Utf8(None));
+        // the delimiter is not a literal
+        {
+            let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]);
+            let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]);

Review Comment:
   so cool!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org