You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/12 11:57:47 UTC

(arrow-datafusion) branch main updated: feat: support UDAF in substrait producer/consumer (#8119)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 824bb66370 feat: support UDAF in substrait producer/consumer (#8119)
824bb66370 is described below

commit 824bb66370eba3cd93a21b4a594315322d4c1718
Author: Ruihang Xia <wa...@gmail.com>
AuthorDate: Sun Nov 12 19:57:41 2023 +0800

    feat: support UDAF in substrait producer/consumer (#8119)
    
    * feat: support UDAF in substrait producer/consumer
    
    Signed-off-by: Ruihang Xia <wa...@gmail.com>
    
    * Update datafusion/substrait/src/logical_plan/consumer.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * remove redundant to_lowercase
    
    Signed-off-by: Ruihang Xia <wa...@gmail.com>
    
    ---------
    
    Signed-off-by: Ruihang Xia <wa...@gmail.com>
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 45 ++++++++++-----
 datafusion/substrait/src/logical_plan/producer.rs  | 41 +++++++++++---
 .../tests/cases/roundtrip_logical_plan.rs          | 64 +++++++++++++++++++++-
 3 files changed, 125 insertions(+), 25 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index c6bcbb479e..f4c36557da 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,6 +19,7 @@ use async_recursion::async_recursion;
 use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
 use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};
 
+use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{
     aggregate_function, window_function::find_df_window_func, BinaryExpr,
     BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
@@ -365,6 +366,7 @@ pub async fn from_substrait_rel(
                                 _ => false,
                             };
                             from_substrait_agg_func(
+                                ctx,
                                 f,
                                 input.schema(),
                                 extensions,
@@ -660,6 +662,7 @@ pub async fn from_substriat_func_args(
 
 /// Convert Substrait AggregateFunction to DataFusion Expr
 pub async fn from_substrait_agg_func(
+    ctx: &SessionContext,
     f: &AggregateFunction,
     input_schema: &DFSchema,
     extensions: &HashMap<u32, &String>,
@@ -680,23 +683,37 @@ pub async fn from_substrait_agg_func(
         args.push(arg_expr?.as_ref().clone());
     }
 
-    let fun = match extensions.get(&f.function_reference) {
-        Some(function_name) => {
-            aggregate_function::AggregateFunction::from_str(function_name)
-        }
-        None => not_impl_err!(
-            "Aggregated function not found: function anchor = {:?}",
+    let Some(function_name) = extensions.get(&f.function_reference) else {
+        return plan_err!(
+            "Aggregate function not registered: function anchor = {:?}",
             f.function_reference
-        ),
+        );
     };
 
-    Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
-        fun: fun.unwrap(),
-        args,
-        distinct,
-        filter,
-        order_by,
-    })))
+    // try udaf first, then built-in aggr fn.
+    if let Ok(fun) = ctx.udaf(function_name) {
+        Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
+            fun,
+            args,
+            filter,
+            order_by,
+        })))
+    } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
+    {
+        Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
+            fun,
+            args,
+            distinct,
+            filter,
+            order_by,
+        })))
+    } else {
+        not_impl_err!(
+            "Aggregated function {} is not supported: function anchor = {:?}",
+            function_name,
+            f.function_reference
+        )
+    }
 }
 
 /// Convert Substrait Rex to DataFusion Expr
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 142b6c3628..6fe8eca337 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure(
             for arg in args {
                 arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
             }
-            let function_name = fun.to_string().to_lowercase();
-            let function_anchor = _register_function(function_name, extension_info);
+            let function_anchor = _register_function(fun.to_string(), extension_info);
             Ok(Measure {
                 measure: Some(AggregateFunction {
                     function_reference: function_anchor,
@@ -610,6 +609,34 @@ pub fn to_substrait_agg_measure(
                 }
             })
         }
+        Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
+            let sorts = if let Some(order_by) = order_by {
+                order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
+            } else {
+                vec![]
+            };
+            let mut arguments: Vec<FunctionArgument> = vec![];
+            for arg in args {
+                arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+            }
+            let function_anchor = _register_function(fun.name.clone(), extension_info);
+            Ok(Measure {
+                measure: Some(AggregateFunction {
+                    function_reference: function_anchor,
+                    arguments,
+                    sorts,
+                    output_type: None,
+                    invocation: AggregationInvocation::All as i32,
+                    phase: AggregationPhase::Unspecified as i32,
+                    args: vec![],
+                    options: vec![],
+                }),
+                filter: match filter {
+                    Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
+                    None => None
+                }
+            })
+        },
         Expr::Alias(Alias{expr,..})=> {
             to_substrait_agg_measure(expr, schema, extension_info)
         }
@@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func(
         HashMap<String, u32>,
     ),
 ) -> Expression {
-    let function_name = operator_to_name(op).to_string().to_lowercase();
-    let function_anchor = _register_function(function_name, extension_info);
+    let function_anchor =
+        _register_function(operator_to_name(op).to_string(), extension_info);
     Expression {
         rex_type: Some(RexType::ScalarFunction(ScalarFunction {
             function_reference: function_anchor,
@@ -807,8 +834,7 @@ pub fn to_substrait_rex(
                     )?)),
                 });
             }
-            let function_name = fun.to_string().to_lowercase();
-            let function_anchor = _register_function(function_name, extension_info);
+            let function_anchor = _register_function(fun.to_string(), extension_info);
             Ok(Expression {
                 rex_type: Some(RexType::ScalarFunction(ScalarFunction {
                     function_reference: function_anchor,
@@ -973,8 +999,7 @@ pub fn to_substrait_rex(
             window_frame,
         }) => {
             // function reference
-            let function_name = fun.to_string().to_lowercase();
-            let function_anchor = _register_function(function_name, extension_info);
+            let function_anchor = _register_function(fun.to_string(), extension_info);
             // arguments
             let mut arguments: Vec<FunctionArgument> = vec![];
             for arg in args {
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 582e5a5d7c..cee3a34649 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion::arrow::array::ArrayRef;
+use datafusion::physical_plan::Accumulator;
+use datafusion::scalar::ScalarValue;
 use datafusion_substrait::logical_plan::{
     consumer::from_substrait_plan, producer::to_substrait_plan,
 };
@@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result};
 use datafusion::execution::context::SessionState;
 use datafusion::execution::registry::SerializerRegistry;
 use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
+use datafusion::logical_expr::{
+    Extension, LogicalPlan, UserDefinedLogicalNode, Volatility,
+};
 use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
 use datafusion::prelude::*;
 
@@ -636,6 +641,56 @@ async fn extension_logical_plan() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn roundtrip_aggregate_udf() -> Result<()> {
+    #[derive(Debug)]
+    struct Dummy {}
+
+    impl Accumulator for Dummy {
+        fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+            Ok(vec![])
+        }
+
+        fn update_batch(
+            &mut self,
+            _values: &[ArrayRef],
+        ) -> datafusion::error::Result<()> {
+            Ok(())
+        }
+
+        fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
+            Ok(())
+        }
+
+        fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
+            Ok(ScalarValue::Float64(None))
+        }
+
+        fn size(&self) -> usize {
+            std::mem::size_of_val(self)
+        }
+    }
+
+    let dummy_agg = create_udaf(
+        // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
+        "dummy_agg",
+        // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
+        vec![DataType::Int64],
+        // the return type; DataFusion expects this to match the type returned by `evaluate`.
+        Arc::new(DataType::Int64),
+        Volatility::Immutable,
+        // This is the accumulator factory; DataFusion uses it to create new accumulators.
+        Arc::new(|_| Ok(Box::new(Dummy {}))),
+        // This is the description of the state. `state()` must match the types here.
+        Arc::new(vec![DataType::Float64, DataType::UInt32]),
+    );
+
+    let ctx = create_context().await?;
+    ctx.register_udaf(dummy_agg);
+
+    roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
+}
+
 fn check_post_join_filters(rel: &Rel) -> Result<()> {
     // search for target_rel and field value in proto
     match &rel.rel_type {
@@ -772,8 +827,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
     Ok(())
 }
 
-async fn roundtrip(sql: &str) -> Result<()> {
-    let ctx = create_context().await?;
+async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> {
     let df = ctx.sql(sql).await?;
     let plan = df.into_optimized_plan()?;
     let proto = to_substrait_plan(&plan, &ctx)?;
@@ -789,6 +843,10 @@ async fn roundtrip(sql: &str) -> Result<()> {
     Ok(())
 }
 
+async fn roundtrip(sql: &str) -> Result<()> {
+    roundtrip_with_ctx(sql, create_context().await?).await
+}
+
 async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;