You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink and is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/05/30 17:39:30 UTC

[arrow-datafusion] branch main updated: Substrait: Support Expr::ScalarFunction (#6471)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fd9bdad453 Substrait: Support Expr::ScalarFunction (#6471)
fd9bdad453 is described below

commit fd9bdad45397d195334d2cb797ca50f3cc5928c3
Author: Jay Zhan <ja...@gmail.com>
AuthorDate: Wed May 31 01:39:23 2023 +0800

    Substrait: Support Expr::ScalarFunction (#6471)
    
    * Add support for abs
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * fmt
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * Add Op or ScalarFunction for two arguments cases
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * refactor name
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    * address comment
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
    
    ---------
    
    Signed-off-by: jayzhan211 <ja...@gmail.com>
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 129 ++++++++++++++++-----
 datafusion/substrait/src/logical_plan/producer.rs  |  28 ++++-
 .../substrait/tests/roundtrip_logical_plan.rs      |  15 +++
 3 files changed, 143 insertions(+), 29 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 9343d31130..f914b62a14 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,8 +19,8 @@ use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
 use datafusion::common::{DFField, DFSchema, DFSchemaRef};
 use datafusion::logical_expr::{
-    aggregate_function, window_function::find_df_window_func, BinaryExpr, Case, Expr,
-    LogicalPlan, Operator,
+    aggregate_function, window_function::find_df_window_func, BinaryExpr,
+    BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
 };
 use datafusion::logical_expr::{build_join_schema, Extension, LogicalPlanBuilder};
 use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
@@ -64,6 +64,11 @@ use crate::variation_const::{
     TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
 };
 
+enum ScalarFunctionType {
+    Builtin(BuiltinScalarFunction),
+    Op(Operator),
+}
+
 pub fn name_to_op(name: &str) -> Result<Operator> {
     match name {
         "equal" => Ok(Operator::Eq),
@@ -97,6 +102,20 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
     }
 }
 
+fn name_to_op_or_scalar_function(name: &str) -> Result<ScalarFunctionType> {
+    if let Ok(op) = name_to_op(name) {
+        return Ok(ScalarFunctionType::Op(op));
+    }
+
+    if let Ok(fun) = BuiltinScalarFunction::from_str(name) {
+        return Ok(ScalarFunctionType::Builtin(fun));
+    }
+
+    Err(DataFusionError::NotImplemented(format!(
+        "Unsupported function name: {name:?}"
+    )))
+}
+
 /// Convert Substrait Plan to DataFusion DataFrame
 pub async fn from_substrait_plan(
     ctx: &mut SessionContext,
@@ -727,38 +746,92 @@ pub async fn from_substrait_rex(
                 else_expr,
             })))
         }
-        Some(RexType::ScalarFunction(f)) => {
-            assert!(f.arguments.len() == 2);
-            let op = match extensions.get(&f.function_reference) {
-                Some(fname) => name_to_op(fname),
-                None => Err(DataFusionError::NotImplemented(format!(
-                    "Aggregated function not found: function reference = {:?}",
-                    f.function_reference
-                ))),
-            };
-            match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
+        Some(RexType::ScalarFunction(f)) => match f.arguments.len() {
+            // BinaryExpr or ScalarFunction
+            2 => match (&f.arguments[0].arg_type, &f.arguments[1].arg_type) {
                 (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
-                    Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
-                        left: Box::new(
-                            from_substrait_rex(l, input_schema, extensions)
-                                .await?
-                                .as_ref()
-                                .clone(),
-                        ),
-                        op: op?,
-                        right: Box::new(
-                            from_substrait_rex(r, input_schema, extensions)
-                                .await?
-                                .as_ref()
-                                .clone(),
-                        ),
-                    })))
+                    let op_or_fun = match extensions.get(&f.function_reference) {
+                        Some(fname) => name_to_op_or_scalar_function(fname),
+                        None => Err(DataFusionError::NotImplemented(format!(
+                            "Aggregated function not found: function reference = {:?}",
+                            f.function_reference
+                        ))),
+                    };
+                    match op_or_fun {
+                        Ok(ScalarFunctionType::Op(op)) => {
+                            return Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
+                                left: Box::new(
+                                    from_substrait_rex(l, input_schema, extensions)
+                                        .await?
+                                        .as_ref()
+                                        .clone(),
+                                ),
+                                op,
+                                right: Box::new(
+                                    from_substrait_rex(r, input_schema, extensions)
+                                        .await?
+                                        .as_ref()
+                                        .clone(),
+                                ),
+                            })))
+                        }
+                        Ok(ScalarFunctionType::Builtin(fun)) => {
+                            Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+                                fun,
+                                args: vec![
+                                    from_substrait_rex(l, input_schema, extensions)
+                                        .await?
+                                        .as_ref()
+                                        .clone(),
+                                    from_substrait_rex(r, input_schema, extensions)
+                                        .await?
+                                        .as_ref()
+                                        .clone(),
+                                ],
+                            })))
+                        }
+                        Err(e) => Err(e),
+                    }
                 }
                 (l, r) => Err(DataFusionError::NotImplemented(format!(
                     "Invalid arguments for binary expression: {l:?} and {r:?}"
                 ))),
+            },
+            // ScalarFunction
+            _ => {
+                let fun = match extensions.get(&f.function_reference) {
+                    Some(fname) => BuiltinScalarFunction::from_str(fname),
+                    None => Err(DataFusionError::NotImplemented(format!(
+                        "Aggregated function not found: function reference = {:?}",
+                        f.function_reference
+                    ))),
+                };
+
+                let mut args: Vec<Expr> = vec![];
+                for arg in f.arguments.iter() {
+                    match &arg.arg_type {
+                        Some(ArgType::Value(e)) => {
+                            args.push(
+                                from_substrait_rex(e, input_schema, extensions)
+                                    .await?
+                                    .as_ref()
+                                    .clone(),
+                            );
+                        }
+                        e => {
+                            return Err(DataFusionError::NotImplemented(format!(
+                                "Invalid arguments for scalar function: {e:?}"
+                            )))
+                        }
+                    }
+                }
+
+                Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
+                    fun: fun?,
+                    args,
+                })))
             }
-        }
+        },
         Some(RexType::Literal(lit)) => {
             let scalar_value = from_substrait_literal(lit)?;
             Ok(Arc::new(Expr::Literal(scalar_value)))
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 872760390a..785bfa4ea6 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -28,7 +28,9 @@ use datafusion::{
 use datafusion::common::DFSchemaRef;
 #[allow(unused_imports)]
 use datafusion::logical_expr::aggregate_function;
-use datafusion::logical_expr::expr::{BinaryExpr, Case, Cast, Sort, WindowFunction};
+use datafusion::logical_expr::expr::{
+    BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction,
+};
 use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
 use datafusion::prelude::{binary_expr, Expr};
 use prost_types::Any as ProtoAny;
@@ -564,6 +566,7 @@ pub fn make_binary_op_scalar_func(
 }
 
 /// Convert DataFusion Expr to Substrait Rex
+#[allow(deprecated)]
 pub fn to_substrait_rex(
     expr: &Expr,
     schema: &DFSchemaRef,
@@ -573,6 +576,29 @@ pub fn to_substrait_rex(
     ),
 ) -> Result<Expression> {
     match expr {
+        Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
+            let mut arguments: Vec<FunctionArgument> = vec![];
+            for arg in args {
+                arguments.push(FunctionArgument {
+                    arg_type: Some(ArgType::Value(to_substrait_rex(
+                        arg,
+                        schema,
+                        extension_info,
+                    )?)),
+                });
+            }
+            let function_name = fun.to_string().to_lowercase();
+            let function_anchor = _register_function(function_name, extension_info);
+            Ok(Expression {
+                rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+                    function_reference: function_anchor,
+                    arguments,
+                    output_type: None,
+                    args: vec![],
+                    options: vec![],
+                })),
+            })
+        }
         Expr::Between(Between {
             expr,
             negated,
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 77a02cbf47..8cdf89b294 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -278,6 +278,21 @@ mod tests {
         .await
     }
 
+    #[tokio::test]
+    async fn simple_scalar_function_abs() -> Result<()> {
+        roundtrip("SELECT ABS(a) FROM data").await
+    }
+
+    #[tokio::test]
+    async fn simple_scalar_function_pow() -> Result<()> {
+        roundtrip("SELECT POW(a, 2) FROM data").await
+    }
+
+    #[tokio::test]
+    async fn simple_scalar_function_substr() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await
+    }
+
     #[tokio::test]
     async fn case_without_base_expression() -> Result<()> {
         roundtrip(