You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/12 11:57:47 UTC
(arrow-datafusion) branch main updated: feat: support UDAF in substrait producer/consumer (#8119)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 824bb66370 feat: support UDAF in substrait producer/consumer (#8119)
824bb66370 is described below
commit 824bb66370eba3cd93a21b4a594315322d4c1718
Author: Ruihang Xia <wa...@gmail.com>
AuthorDate: Sun Nov 12 19:57:41 2023 +0800
feat: support UDAF in substrait producer/consumer (#8119)
* feat: support UDAF in substrait producer/consumer
Signed-off-by: Ruihang Xia <wa...@gmail.com>
* Update datafusion/substrait/src/logical_plan/consumer.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* remove redundant to_lowercase
Signed-off-by: Ruihang Xia <wa...@gmail.com>
---------
Signed-off-by: Ruihang Xia <wa...@gmail.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/substrait/src/logical_plan/consumer.rs | 45 ++++++++++-----
datafusion/substrait/src/logical_plan/producer.rs | 41 +++++++++++---
.../tests/cases/roundtrip_logical_plan.rs | 64 +++++++++++++++++++++-
3 files changed, 125 insertions(+), 25 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index c6bcbb479e..f4c36557da 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,6 +19,7 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};
+use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
@@ -365,6 +366,7 @@ pub async fn from_substrait_rel(
_ => false,
};
from_substrait_agg_func(
+ ctx,
f,
input.schema(),
extensions,
@@ -660,6 +662,7 @@ pub async fn from_substriat_func_args(
/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
+ ctx: &SessionContext,
f: &AggregateFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
@@ -680,23 +683,37 @@ pub async fn from_substrait_agg_func(
args.push(arg_expr?.as_ref().clone());
}
- let fun = match extensions.get(&f.function_reference) {
- Some(function_name) => {
- aggregate_function::AggregateFunction::from_str(function_name)
- }
- None => not_impl_err!(
- "Aggregated function not found: function anchor = {:?}",
+ let Some(function_name) = extensions.get(&f.function_reference) else {
+ return plan_err!(
+ "Aggregate function not registered: function anchor = {:?}",
f.function_reference
- ),
+ );
};
- Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
- fun: fun.unwrap(),
- args,
- distinct,
- filter,
- order_by,
- })))
+ // try udaf first, then built-in aggr fn.
+ if let Ok(fun) = ctx.udaf(function_name) {
+ Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
+ fun,
+ args,
+ filter,
+ order_by,
+ })))
+ } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
+ {
+ Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
+ fun,
+ args,
+ distinct,
+ filter,
+ order_by,
+ })))
+ } else {
+ not_impl_err!(
+ "Aggregated function {} is not supported: function anchor = {:?}",
+ function_name,
+ f.function_reference
+ )
+ }
}
/// Convert Substrait Rex to DataFusion Expr
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 142b6c3628..6fe8eca337 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
- let function_name = fun.to_string().to_lowercase();
- let function_anchor = _register_function(function_name, extension_info);
+ let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
@@ -610,6 +609,34 @@ pub fn to_substrait_agg_measure(
}
})
}
+ Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
+ let sorts = if let Some(order_by) = order_by {
+ order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
+ } else {
+ vec![]
+ };
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+ }
+ let function_anchor = _register_function(fun.name.clone(), extension_info);
+ Ok(Measure {
+ measure: Some(AggregateFunction {
+ function_reference: function_anchor,
+ arguments,
+ sorts,
+ output_type: None,
+ invocation: AggregationInvocation::All as i32,
+ phase: AggregationPhase::Unspecified as i32,
+ args: vec![],
+ options: vec![],
+ }),
+ filter: match filter {
+ Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
+ None => None
+ }
+ })
+ },
Expr::Alias(Alias{expr,..})=> {
to_substrait_agg_measure(expr, schema, extension_info)
}
@@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func(
HashMap<String, u32>,
),
) -> Expression {
- let function_name = operator_to_name(op).to_string().to_lowercase();
- let function_anchor = _register_function(function_name, extension_info);
+ let function_anchor =
+ _register_function(operator_to_name(op).to_string(), extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -807,8 +834,7 @@ pub fn to_substrait_rex(
)?)),
});
}
- let function_name = fun.to_string().to_lowercase();
- let function_anchor = _register_function(function_name, extension_info);
+ let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -973,8 +999,7 @@ pub fn to_substrait_rex(
window_frame,
}) => {
// function reference
- let function_name = fun.to_string().to_lowercase();
- let function_anchor = _register_function(function_name, extension_info);
+ let function_anchor = _register_function(fun.to_string(), extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 582e5a5d7c..cee3a34649 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.
+use datafusion::arrow::array::ArrayRef;
+use datafusion::physical_plan::Accumulator;
+use datafusion::scalar::ScalarValue;
use datafusion_substrait::logical_plan::{
consumer::from_substrait_plan, producer::to_substrait_plan,
};
@@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
+use datafusion::logical_expr::{
+ Extension, LogicalPlan, UserDefinedLogicalNode, Volatility,
+};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;
@@ -636,6 +641,56 @@ async fn extension_logical_plan() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn roundtrip_aggregate_udf() -> Result<()> {
+ #[derive(Debug)]
+ struct Dummy {}
+
+ impl Accumulator for Dummy {
+ fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
+ Ok(vec![])
+ }
+
+ fn update_batch(
+ &mut self,
+ _values: &[ArrayRef],
+ ) -> datafusion::error::Result<()> {
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
+ Ok(())
+ }
+
+ fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
+ Ok(ScalarValue::Float64(None))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+ }
+
+ let dummy_agg = create_udaf(
+ // the name; used to represent it in plan descriptions and in the registry, to use in SQL.
+ "dummy_agg",
+ // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
+ vec![DataType::Int64],
+ // the return type; DataFusion expects this to match the type returned by `evaluate`.
+ Arc::new(DataType::Int64),
+ Volatility::Immutable,
+ // This is the accumulator factory; DataFusion uses it to create new accumulators.
+ Arc::new(|_| Ok(Box::new(Dummy {}))),
+ // This is the description of the state. `state()` must match the types here.
+ Arc::new(vec![DataType::Float64, DataType::UInt32]),
+ );
+
+ let ctx = create_context().await?;
+ ctx.register_udaf(dummy_agg);
+
+ roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
+}
+
fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
@@ -772,8 +827,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
Ok(())
}
-async fn roundtrip(sql: &str) -> Result<()> {
- let ctx = create_context().await?;
+async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> {
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
@@ -789,6 +843,10 @@ async fn roundtrip(sql: &str) -> Result<()> {
Ok(())
}
+async fn roundtrip(sql: &str) -> Result<()> {
+ roundtrip_with_ctx(sql, create_context().await?).await
+}
+
async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;