You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/30 22:05:03 UTC
(arrow-datafusion) branch main updated: Refactor aggregate function handling (#8358)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a49740f675 Refactor aggregate function handling (#8358)
a49740f675 is described below
commit a49740f675b2279e60b1898114f2e4d81ed43441
Author: Alex Huang <hu...@gmail.com>
AuthorDate: Thu Nov 30 23:04:56 2023 +0100
Refactor aggregate function handling (#8358)
* Refactor aggregate function handling
* fix ci
* update comment
* fix ci
* simplify the code
* fix fmt
* fix ci
* fix clippy
---
datafusion/core/src/datasource/listing/helpers.rs | 3 +-
datafusion/core/src/physical_planner.rs | 128 +++++------
datafusion/expr/src/aggregate_function.rs | 2 +-
datafusion/expr/src/expr.rs | 112 ++++++----
datafusion/expr/src/expr_schema.rs | 26 ++-
datafusion/expr/src/tree_node/expr.rs | 66 +++---
datafusion/expr/src/udaf.rs | 11 +-
datafusion/expr/src/utils.rs | 10 +-
.../optimizer/src/analyzer/count_wildcard_rule.rs | 15 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 68 +++---
.../optimizer/src/common_subexpr_eliminate.rs | 11 +-
datafusion/optimizer/src/decorrelate.rs | 27 ++-
datafusion/optimizer/src/push_down_filter.rs | 1 -
.../src/simplify_expressions/expr_simplifier.rs | 1 -
.../optimizer/src/single_distinct_to_groupby.rs | 5 +-
datafusion/proto/src/logical_plan/from_proto.rs | 3 +-
datafusion/proto/src/logical_plan/to_proto.rs | 246 +++++++++++----------
.../proto/tests/cases/roundtrip_logical_plan.rs | 3 +-
datafusion/sql/src/expr/function.rs | 4 +-
datafusion/sql/src/expr/mod.rs | 3 +-
datafusion/sql/src/select.rs | 9 +-
datafusion/substrait/src/logical_plan/consumer.rs | 19 +-
datafusion/substrait/src/logical_plan/producer.rs | 126 ++++++-----
23 files changed, 462 insertions(+), 437 deletions(-)
diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index f9b02f4d0c..0c39877cd1 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -122,8 +122,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
// - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases
// - Can `Wildcard` be considered as a `Literal`?
// - ScalarVariable could be `applicable`, but that would require access to the context
- Expr::AggregateUDF { .. }
- | Expr::AggregateFunction { .. }
+ Expr::AggregateFunction { .. }
| Expr::Sort { .. }
| Expr::WindowFunction { .. }
| Expr::Wildcard { .. }
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index ef364c22ee..9e64eb9c51 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -82,8 +82,9 @@ use datafusion_common::{
};
use datafusion_expr::dml::{CopyOptions, CopyTo};
use datafusion_expr::expr::{
- self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast,
- GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction,
+ self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr,
+ Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast,
+ WindowFunction,
};
use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
@@ -229,30 +230,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&fun.to_string(), false, args)
}
Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def,
distinct,
args,
- ..
- }) => create_function_physical_name(&fun.to_string(), *distinct, args),
- Expr::AggregateUDF(AggregateUDF {
- fun,
- args,
filter,
order_by,
- }) => {
- // TODO: Add support for filter and order by in AggregateUDF
- if filter.is_some() {
- return exec_err!("aggregate expression with filter is not supported");
+ }) => match func_def {
+ AggregateFunctionDefinition::BuiltIn(..) => {
+ create_function_physical_name(func_def.name(), *distinct, args)
}
- if order_by.is_some() {
- return exec_err!("aggregate expression with order_by is not supported");
+ AggregateFunctionDefinition::UDF(fun) => {
+ // TODO: Add support for filter and order by in AggregateUDF
+ if filter.is_some() {
+ return exec_err!(
+ "aggregate expression with filter is not supported"
+ );
+ }
+ if order_by.is_some() {
+ return exec_err!(
+ "aggregate expression with order_by is not supported"
+ );
+ }
+ let names = args
+ .iter()
+ .map(|e| create_physical_name(e, false))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(format!("{}({})", fun.name(), names.join(",")))
}
- let mut names = Vec::with_capacity(args.len());
- for e in args {
- names.push(create_physical_name(e, false)?);
+ AggregateFunctionDefinition::Name(_) => {
+ internal_err!("Aggregate function `Expr` with name should be resolved.")
}
- Ok(format!("{}({})", fun.name(), names.join(",")))
- }
+ },
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
@@ -1705,7 +1713,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
) -> Result<AggregateExprWithOptionalArgs> {
match e {
Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def,
distinct,
args,
filter,
@@ -1746,63 +1754,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
),
None => None,
};
- let ordering_reqs = order_by.clone().unwrap_or(vec![]);
- let agg_expr = aggregates::create_aggregate_expr(
- fun,
- *distinct,
- &args,
- &ordering_reqs,
- physical_input_schema,
- name,
- )?;
- Ok((agg_expr, filter, order_by))
- }
- Expr::AggregateUDF(AggregateUDF {
- fun,
- args,
- filter,
- order_by,
- }) => {
- let args = args
- .iter()
- .map(|e| {
- create_physical_expr(
- e,
- logical_input_schema,
+ let (agg_expr, filter, order_by) = match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ let ordering_reqs = order_by.clone().unwrap_or(vec![]);
+ let agg_expr = aggregates::create_aggregate_expr(
+ fun,
+ *distinct,
+ &args,
+ &ordering_reqs,
physical_input_schema,
- execution_props,
+ name,
+ )?;
+ (agg_expr, filter, order_by)
+ }
+ AggregateFunctionDefinition::UDF(fun) => {
+ let agg_expr = udaf::create_aggregate_expr(
+ fun,
+ &args,
+ physical_input_schema,
+ name,
+ );
+ (agg_expr?, filter, order_by)
+ }
+ AggregateFunctionDefinition::Name(_) => {
+ return internal_err!(
+ "Aggregate function name should have been resolved"
)
- })
- .collect::<Result<Vec<_>>>()?;
-
- let filter = match filter {
- Some(e) => Some(create_physical_expr(
- e,
- logical_input_schema,
- physical_input_schema,
- execution_props,
- )?),
- None => None,
- };
- let order_by = match order_by {
- Some(e) => Some(
- e.iter()
- .map(|expr| {
- create_physical_sort_expr(
- expr,
- logical_input_schema,
- physical_input_schema,
- execution_props,
- )
- })
- .collect::<Result<Vec<_>>>()?,
- ),
- None => None,
+ }
};
-
- let agg_expr =
- udaf::create_aggregate_expr(fun, &args, physical_input_schema, name);
- Ok((agg_expr?, filter, order_by))
+ Ok((agg_expr, filter, order_by))
}
other => internal_err!("Invalid aggregate expression '{other:?}'"),
}
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 4611c7fb10..cea72c3cb5 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -105,7 +105,7 @@ pub enum AggregateFunction {
}
impl AggregateFunction {
- fn name(&self) -> &str {
+ pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Count => "COUNT",
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b46d204faa..256f5b210e 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -154,8 +154,6 @@ pub enum Expr {
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
WindowFunction(WindowFunction),
- /// aggregate function
- AggregateUDF(AggregateUDF),
/// Returns whether the list contains the expr value.
InList(InList),
/// EXISTS subquery
@@ -484,11 +482,33 @@ impl Sort {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// Defines which implementation of an aggregate function DataFusion should call.
+pub enum AggregateFunctionDefinition {
+ BuiltIn(aggregate_function::AggregateFunction),
+ /// Resolved to a user defined aggregate function
+ UDF(Arc<crate::AggregateUDF>),
+ /// An aggregation function constructed with a name. This variant cannot be executed directly
+ /// and instead must be resolved to one of the other variants prior to physical planning.
+ Name(Arc<str>),
+}
+
+impl AggregateFunctionDefinition {
+ /// Function's name for display
+ pub fn name(&self) -> &str {
+ match self {
+ AggregateFunctionDefinition::BuiltIn(fun) => fun.name(),
+ AggregateFunctionDefinition::UDF(udf) => udf.name(),
+ AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(),
+ }
+ }
+}
+
/// Aggregate function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct AggregateFunction {
/// Name of the function
- pub fun: aggregate_function::AggregateFunction,
+ pub func_def: AggregateFunctionDefinition,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
@@ -508,7 +528,24 @@ impl AggregateFunction {
order_by: Option<Vec<Expr>>,
) -> Self {
Self {
- fun,
+ func_def: AggregateFunctionDefinition::BuiltIn(fun),
+ args,
+ distinct,
+ filter,
+ order_by,
+ }
+ }
+
+ /// Create a new AggregateFunction expression with a user-defined function (UDF)
+ pub fn new_udf(
+ udf: Arc<crate::AggregateUDF>,
+ args: Vec<Expr>,
+ distinct: bool,
+ filter: Option<Box<Expr>>,
+ order_by: Option<Vec<Expr>>,
+ ) -> Self {
+ Self {
+ func_def: AggregateFunctionDefinition::UDF(udf),
args,
distinct,
filter,
@@ -736,7 +773,6 @@ impl Expr {
pub fn variant_name(&self) -> &str {
match self {
Expr::AggregateFunction { .. } => "AggregateFunction",
- Expr::AggregateUDF { .. } => "AggregateUDF",
Expr::Alias(..) => "Alias",
Expr::Between { .. } => "Between",
Expr::BinaryExpr { .. } => "BinaryExpr",
@@ -1251,30 +1287,14 @@ impl fmt::Display for Expr {
Ok(())
}
Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def,
distinct,
ref args,
filter,
order_by,
..
}) => {
- fmt_function(f, &fun.to_string(), *distinct, args, true)?;
- if let Some(fe) = filter {
- write!(f, " FILTER (WHERE {fe})")?;
- }
- if let Some(ob) = order_by {
- write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?;
- }
- Ok(())
- }
- Expr::AggregateUDF(AggregateUDF {
- fun,
- ref args,
- filter,
- order_by,
- ..
- }) => {
- fmt_function(f, fun.name(), false, args, true)?;
+ fmt_function(f, func_def.name(), *distinct, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
@@ -1579,39 +1599,39 @@ fn create_name(e: &Expr) -> Result<String> {
Ok(parts.join(" "))
}
Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def,
distinct,
args,
filter,
order_by,
}) => {
- let mut name = create_function_name(&fun.to_string(), *distinct, args)?;
- if let Some(fe) = filter {
- name = format!("{name} FILTER (WHERE {fe})");
- };
- if let Some(order_by) = order_by {
- name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by));
+ let name = match func_def {
+ AggregateFunctionDefinition::BuiltIn(..)
+ | AggregateFunctionDefinition::Name(..) => {
+ create_function_name(func_def.name(), *distinct, args)?
+ }
+ AggregateFunctionDefinition::UDF(..) => {
+ let names: Vec<String> =
+ args.iter().map(create_name).collect::<Result<_>>()?;
+ names.join(",")
+ }
};
- Ok(name)
- }
- Expr::AggregateUDF(AggregateUDF {
- fun,
- args,
- filter,
- order_by,
- }) => {
- let mut names = Vec::with_capacity(args.len());
- for e in args {
- names.push(create_name(e)?);
- }
let mut info = String::new();
if let Some(fe) = filter {
info += &format!(" FILTER (WHERE {fe})");
+ };
+ if let Some(order_by) = order_by {
+ info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by));
+ };
+ match func_def {
+ AggregateFunctionDefinition::BuiltIn(..)
+ | AggregateFunctionDefinition::Name(..) => {
+ Ok(format!("{}{}", name, info))
+ }
+ AggregateFunctionDefinition::UDF(fun) => {
+ Ok(format!("{}({}){}", fun.name(), name, info))
+ }
}
- if let Some(ob) = order_by {
- info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob));
- }
- Ok(format!("{}({}){}", fun.name(), names.join(","), info))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index d5d9c848b2..99b27e8912 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -17,8 +17,8 @@
use super::{Between, Expr, Like};
use crate::expr::{
- AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess,
- GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
+ AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
+ GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
ScalarFunctionDefinition, Sort, TryCast, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
@@ -123,19 +123,22 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
fun.return_type(&data_types)
}
- Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
+ Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
- fun.return_type(&data_types)
- }
- Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- fun.return_type(&data_types)
+ match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ fun.return_type(&data_types)
+ }
+ AggregateFunctionDefinition::UDF(fun) => {
+ Ok(fun.return_type(&data_types)?)
+ }
+ AggregateFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be resolved.")
+ }
+ }
}
Expr::Not(_)
| Expr::IsNull(_)
@@ -252,7 +255,6 @@ impl ExprSchemable for Expr {
| Expr::ScalarFunction(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
- | Expr::AggregateUDF { .. }
| Expr::Placeholder(_) => Ok(true),
Expr::IsNull(_)
| Expr::IsNotNull(_)
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index 474b5f7689..fcb0a4cd93 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -18,9 +18,9 @@
//! Tree node implementation for logical expr
use crate::expr::{
- AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast,
- GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction,
- ScalarFunctionDefinition, Sort, TryCast, WindowFunction,
+ AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
+ Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
+ ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction,
};
use crate::{Expr, GetFieldAccess};
@@ -108,7 +108,7 @@ impl TreeNode for Expr {
expr_vec
}
Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. })
- | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => {
+ => {
let mut expr_vec = args.clone();
if let Some(f) = filter {
@@ -304,17 +304,40 @@ impl TreeNode for Expr {
)),
Expr::AggregateFunction(AggregateFunction {
args,
- fun,
+ func_def,
distinct,
filter,
order_by,
- }) => Expr::AggregateFunction(AggregateFunction::new(
- fun,
- transform_vec(args, &mut transform)?,
- distinct,
- transform_option_box(filter, &mut transform)?,
- transform_option_vec(order_by, &mut transform)?,
- )),
+ }) => match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ Expr::AggregateFunction(AggregateFunction::new(
+ fun,
+ transform_vec(args, &mut transform)?,
+ distinct,
+ transform_option_box(filter, &mut transform)?,
+ transform_option_vec(order_by, &mut transform)?,
+ ))
+ }
+ AggregateFunctionDefinition::UDF(fun) => {
+ let order_by = if let Some(order_by) = order_by {
+ Some(transform_vec(order_by, &mut transform)?)
+ } else {
+ None
+ };
+ Expr::AggregateFunction(AggregateFunction::new_udf(
+ fun,
+ transform_vec(args, &mut transform)?,
+ false,
+ transform_option_box(filter, &mut transform)?,
+ transform_option_vec(order_by, &mut transform)?,
+ ))
+ }
+ AggregateFunctionDefinition::Name(_) => {
+ return internal_err!(
+ "Function `Expr` with name should be resolved."
+ );
+ }
+ },
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup(
transform_vec(exprs, &mut transform)?,
@@ -331,24 +354,7 @@ impl TreeNode for Expr {
))
}
},
- Expr::AggregateUDF(AggregateUDF {
- args,
- fun,
- filter,
- order_by,
- }) => {
- let order_by = if let Some(order_by) = order_by {
- Some(transform_vec(order_by, &mut transform)?)
- } else {
- None
- };
- Expr::AggregateUDF(AggregateUDF::new(
- fun,
- transform_vec(args, &mut transform)?,
- transform_option_box(filter, &mut transform)?,
- transform_option_vec(order_by, &mut transform)?,
- ))
- }
+
Expr::InList(InList {
expr,
list,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index b06e97acc2..cfbca4ab13 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -107,12 +107,13 @@ impl AggregateUDF {
/// This utility allows using the UDAF without requiring access to
/// the registry, such as with the DataFrame API.
pub fn call(&self, args: Vec<Expr>) -> Expr {
- Expr::AggregateUDF(crate::expr::AggregateUDF {
- fun: Arc::new(self.clone()),
+ Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
+ Arc::new(self.clone()),
args,
- filter: None,
- order_by: None,
- })
+ false,
+ None,
+ None,
+ ))
}
/// Returns this function's name
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 7deb13c89b..7d126a0f33 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -291,7 +291,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
- | Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
| Expr::InSubquery(_)
@@ -595,15 +594,12 @@ pub fn group_window_expr_by_sort_keys(
Ok(result)
}
-/// Collect all deeply nested `Expr::AggregateFunction` and
-/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth
+/// Collect all deeply nested `Expr::AggregateFunction`.
+/// They are returned in order of occurrence (depth
/// first), with duplicates omitted.
pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
find_exprs_in_exprs(exprs, &|nested_expr| {
- matches!(
- nested_expr,
- Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. }
- )
+ matches!(nested_expr, Expr::AggregateFunction { .. })
})
}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index b4de322f76..fd84bb8016 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -19,7 +19,7 @@ use crate::analyzer::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
use datafusion_common::Result;
-use datafusion_expr::expr::{AggregateFunction, InSubquery};
+use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::ScalarSubquery;
@@ -144,20 +144,23 @@ impl TreeNodeRewriter for CountWildcardRewriter {
_ => old_expr,
},
Expr::AggregateFunction(AggregateFunction {
- fun: aggregate_function::AggregateFunction::Count,
+ func_def:
+ AggregateFunctionDefinition::BuiltIn(
+ aggregate_function::AggregateFunction::Count,
+ ),
args,
distinct,
filter,
order_by,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
- Expr::AggregateFunction(AggregateFunction {
- fun: aggregate_function::AggregateFunction::Count,
- args: vec![lit(COUNT_STAR_EXPANSION)],
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Count,
+ vec![lit(COUNT_STAR_EXPANSION)],
distinct,
filter,
order_by,
- })
+ ))
}
_ => old_expr,
},
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index eb5d8c53a5..bedc86e2f4 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -28,8 +28,8 @@ use datafusion_common::{
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction,
- WindowFunction,
+ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
+ InSubquery, Like, ScalarFunction, WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
@@ -346,39 +346,39 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
},
Expr::AggregateFunction(expr::AggregateFunction {
- fun,
+ func_def,
args,
distinct,
filter,
order_by,
- }) => {
- let new_expr = coerce_agg_exprs_for_signature(
- &fun,
- &args,
- &self.schema,
- &fun.signature(),
- )?;
- let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
- fun, new_expr, distinct, filter, order_by,
- ));
- Ok(expr)
- }
- Expr::AggregateUDF(expr::AggregateUDF {
- fun,
- args,
- filter,
- order_by,
- }) => {
- let new_expr = coerce_arguments_for_signature(
- args.as_slice(),
- &self.schema,
- fun.signature(),
- )?;
- let expr = Expr::AggregateUDF(expr::AggregateUDF::new(
- fun, new_expr, filter, order_by,
- ));
- Ok(expr)
- }
+ }) => match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ let new_expr = coerce_agg_exprs_for_signature(
+ &fun,
+ &args,
+ &self.schema,
+ &fun.signature(),
+ )?;
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ fun, new_expr, distinct, filter, order_by,
+ ));
+ Ok(expr)
+ }
+ AggregateFunctionDefinition::UDF(fun) => {
+ let new_expr = coerce_arguments_for_signature(
+ args.as_slice(),
+ &self.schema,
+ fun.signature(),
+ )?;
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
+ fun, new_expr, false, filter, order_by,
+ ));
+ Ok(expr)
+ }
+ AggregateFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be resolved.")
+ }
+ },
Expr::WindowFunction(WindowFunction {
fun,
args,
@@ -914,9 +914,10 @@ mod test {
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
- let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
+ let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
Arc::new(my_avg),
vec![lit(10i64)],
+ false,
None,
None,
));
@@ -941,9 +942,10 @@ mod test {
&accumulator,
&state_type,
);
- let udaf = Expr::AggregateUDF(expr::AggregateUDF::new(
+ let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
Arc::new(my_avg),
vec![lit("10")],
+ false,
None,
None,
));
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index f5ad767c50..1d21407a69 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -509,10 +509,9 @@ enum ExprMask {
/// - [`Sort`](Expr::Sort)
/// - [`Wildcard`](Expr::Wildcard)
/// - [`AggregateFunction`](Expr::AggregateFunction)
- /// - [`AggregateUDF`](Expr::AggregateUDF)
Normal,
- /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF).
+ /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction).
NormalAndAggregates,
}
@@ -528,10 +527,7 @@ impl ExprMask {
| Expr::Wildcard { .. }
);
- let is_aggr = matches!(
- expr,
- Expr::AggregateFunction(..) | Expr::AggregateUDF { .. }
- );
+ let is_aggr = matches!(expr, Expr::AggregateFunction(..));
match self {
Self::Normal => is_normal_minus_aggregates || is_aggr,
@@ -908,7 +904,7 @@ mod test {
let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
let udf_agg = |inner: Expr| {
- Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new(
+ Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
Arc::new(AggregateUDF::new(
"my_agg",
&Signature::exact(vec![DataType::UInt32], Volatility::Stable),
@@ -917,6 +913,7 @@ mod test {
&state_type,
)),
vec![inner],
+ false,
None,
None,
))
diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs
index ed6f472186..b1000f042c 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -22,7 +22,7 @@ use datafusion_common::tree_node::{
};
use datafusion_common::{plan_err, Result};
use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue};
-use datafusion_expr::expr::Alias;
+use datafusion_expr::expr::{AggregateFunctionDefinition, Alias};
use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction};
use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion_physical_expr::execution_props::ExecutionProps;
@@ -372,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch(
for e in agg_expr.iter() {
let result_expr = e.clone().transform_up(&|expr| {
let new_expr = match expr {
- Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => {
- if matches!(fun, datafusion_expr::AggregateFunction::Count) {
- Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0))))
- } else {
- Transformed::Yes(Expr::Literal(ScalarValue::Null))
+ Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => {
+ match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ if matches!(fun, datafusion_expr::AggregateFunction::Count) {
+ Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(
+ 0,
+ ))))
+ } else {
+ Transformed::Yes(Expr::Literal(ScalarValue::Null))
+ }
+ }
+ AggregateFunctionDefinition::UDF { .. } => {
+ Transformed::Yes(Expr::Literal(ScalarValue::Null))
+ }
+ AggregateFunctionDefinition::Name(_) => {
+ Transformed::Yes(Expr::Literal(ScalarValue::Null))
+ }
}
}
- Expr::AggregateUDF(_) => {
- Transformed::Yes(Expr::Literal(ScalarValue::Null))
- }
_ => Transformed::No(expr),
};
Ok(new_expr)
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 95eeee931b..bad6e24715 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -253,7 +253,6 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
Expr::Sort(_)
| Expr::AggregateFunction(_)
| Expr::WindowFunction(_)
- | Expr::AggregateUDF { .. }
| Expr::Wildcard { .. }
| Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
})?;
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 3310bfed75..c7366e1761 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -332,7 +332,6 @@ impl<'a> ConstEvaluator<'a> {
// Has no runtime cost, but needed during planning
Expr::Alias(..)
| Expr::AggregateFunction { .. }
- | Expr::AggregateUDF { .. }
| Expr::ScalarVariable(_, _)
| Expr::Column(_)
| Expr::OuterReferenceColumn(_, _)
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index fa142438c4..7e6fb6b355 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -23,6 +23,7 @@ use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, Result};
+use datafusion_expr::expr::AggregateFunctionDefinition;
use datafusion_expr::{
aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
@@ -70,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def: AggregateFunctionDefinition::BuiltIn(fun),
distinct,
args,
filter,
@@ -170,7 +171,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
- fun,
+ func_def: AggregateFunctionDefinition::BuiltIn(fun),
args,
distinct,
..
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index d596998c1d..ae3628bdde 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1744,12 +1744,13 @@ pub fn parse_expr(
ExprType::AggregateUdfExpr(pb) => {
let agg_fn = registry.udaf(pb.fun_name.as_str())?;
- Ok(Expr::AggregateUDF(expr::AggregateUDF::new(
+ Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
agg_fn,
pb.args
.iter()
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, Error>>()?,
+ false,
parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new),
parse_vec_expr(&pb.order_by, registry)?,
)))
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 54be6460c3..b619339674 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -44,8 +44,9 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::expr::{
- self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet,
- InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort,
+ self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess,
+ GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction,
+ ScalarFunctionDefinition, Sort,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -652,104 +653,139 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
}
}
Expr::AggregateFunction(expr::AggregateFunction {
- ref fun,
+ ref func_def,
ref args,
ref distinct,
ref filter,
ref order_by,
}) => {
- let aggr_function = match fun {
- AggregateFunction::ApproxDistinct => {
- protobuf::AggregateFunction::ApproxDistinct
- }
- AggregateFunction::ApproxPercentileCont => {
- protobuf::AggregateFunction::ApproxPercentileCont
- }
- AggregateFunction::ApproxPercentileContWithWeight => {
- protobuf::AggregateFunction::ApproxPercentileContWithWeight
- }
- AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
- AggregateFunction::Min => protobuf::AggregateFunction::Min,
- AggregateFunction::Max => protobuf::AggregateFunction::Max,
- AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
- AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd,
- AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr,
- AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor,
- AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd,
- AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr,
- AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
- AggregateFunction::Count => protobuf::AggregateFunction::Count,
- AggregateFunction::Variance => protobuf::AggregateFunction::Variance,
- AggregateFunction::VariancePop => {
- protobuf::AggregateFunction::VariancePop
- }
- AggregateFunction::Covariance => {
- protobuf::AggregateFunction::Covariance
- }
- AggregateFunction::CovariancePop => {
- protobuf::AggregateFunction::CovariancePop
- }
- AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev,
- AggregateFunction::StddevPop => {
- protobuf::AggregateFunction::StddevPop
- }
- AggregateFunction::Correlation => {
- protobuf::AggregateFunction::Correlation
- }
- AggregateFunction::RegrSlope => {
- protobuf::AggregateFunction::RegrSlope
- }
- AggregateFunction::RegrIntercept => {
- protobuf::AggregateFunction::RegrIntercept
- }
- AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2,
- AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx,
- AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy,
- AggregateFunction::RegrCount => {
- protobuf::AggregateFunction::RegrCount
- }
- AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx,
- AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy,
- AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy,
- AggregateFunction::ApproxMedian => {
- protobuf::AggregateFunction::ApproxMedian
- }
- AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
- AggregateFunction::Median => protobuf::AggregateFunction::Median,
- AggregateFunction::FirstValue => {
- protobuf::AggregateFunction::FirstValueAgg
- }
- AggregateFunction::LastValue => {
- protobuf::AggregateFunction::LastValueAgg
- }
- AggregateFunction::StringAgg => {
- protobuf::AggregateFunction::StringAgg
+ match func_def {
+ AggregateFunctionDefinition::BuiltIn(fun) => {
+ let aggr_function = match fun {
+ AggregateFunction::ApproxDistinct => {
+ protobuf::AggregateFunction::ApproxDistinct
+ }
+ AggregateFunction::ApproxPercentileCont => {
+ protobuf::AggregateFunction::ApproxPercentileCont
+ }
+ AggregateFunction::ApproxPercentileContWithWeight => {
+ protobuf::AggregateFunction::ApproxPercentileContWithWeight
+ }
+ AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
+ AggregateFunction::Min => protobuf::AggregateFunction::Min,
+ AggregateFunction::Max => protobuf::AggregateFunction::Max,
+ AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
+ AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd,
+ AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr,
+ AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor,
+ AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd,
+ AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr,
+ AggregateFunction::Avg => protobuf::AggregateFunction::Avg,
+ AggregateFunction::Count => protobuf::AggregateFunction::Count,
+ AggregateFunction::Variance => protobuf::AggregateFunction::Variance,
+ AggregateFunction::VariancePop => {
+ protobuf::AggregateFunction::VariancePop
+ }
+ AggregateFunction::Covariance => {
+ protobuf::AggregateFunction::Covariance
+ }
+ AggregateFunction::CovariancePop => {
+ protobuf::AggregateFunction::CovariancePop
+ }
+ AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev,
+ AggregateFunction::StddevPop => {
+ protobuf::AggregateFunction::StddevPop
+ }
+ AggregateFunction::Correlation => {
+ protobuf::AggregateFunction::Correlation
+ }
+ AggregateFunction::RegrSlope => {
+ protobuf::AggregateFunction::RegrSlope
+ }
+ AggregateFunction::RegrIntercept => {
+ protobuf::AggregateFunction::RegrIntercept
+ }
+ AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2,
+ AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx,
+ AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy,
+ AggregateFunction::RegrCount => {
+ protobuf::AggregateFunction::RegrCount
+ }
+ AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx,
+ AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy,
+ AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy,
+ AggregateFunction::ApproxMedian => {
+ protobuf::AggregateFunction::ApproxMedian
+ }
+ AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
+ AggregateFunction::Median => protobuf::AggregateFunction::Median,
+ AggregateFunction::FirstValue => {
+ protobuf::AggregateFunction::FirstValueAgg
+ }
+ AggregateFunction::LastValue => {
+ protobuf::AggregateFunction::LastValueAgg
+ }
+ AggregateFunction::StringAgg => {
+ protobuf::AggregateFunction::StringAgg
+ }
+ };
+
+ let aggregate_expr = protobuf::AggregateExprNode {
+ aggr_function: aggr_function.into(),
+ expr: args
+ .iter()
+ .map(|v| v.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ distinct: *distinct,
+ filter: match filter {
+ Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+ None => None,
+ },
+ order_by: match order_by {
+ Some(e) => e
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ None => vec![],
+ },
+ };
+ Self {
+ expr_type: Some(ExprType::AggregateExpr(Box::new(
+ aggregate_expr,
+ ))),
+ }
}
- };
-
- let aggregate_expr = protobuf::AggregateExprNode {
- aggr_function: aggr_function.into(),
- expr: args
- .iter()
- .map(|v| v.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- distinct: *distinct,
- filter: match filter {
- Some(e) => Some(Box::new(e.as_ref().try_into()?)),
- None => None,
- },
- order_by: match order_by {
- Some(e) => e
- .iter()
- .map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- None => vec![],
+ AggregateFunctionDefinition::UDF(fun) => Self {
+ expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
+ protobuf::AggregateUdfExprNode {
+ fun_name: fun.name().to_string(),
+ args: args
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Error>>()?,
+ filter: match filter {
+ Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+ None => None,
+ },
+ order_by: match order_by {
+ Some(e) => e
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ None => vec![],
+ },
+ },
+ ))),
},
- };
- Self {
- expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))),
+ AggregateFunctionDefinition::Name(_) => {
+ return Err(Error::NotImplemented(
+ "Proto serialization error: Trying to serialize a unresolved function"
+ .to_string(),
+ ));
+ }
}
}
+
Expr::ScalarVariable(_, _) => {
return Err(Error::General(
"Proto serialization error: Scalar Variable not supported"
@@ -790,34 +826,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
));
}
},
- Expr::AggregateUDF(expr::AggregateUDF {
- fun,
- args,
- filter,
- order_by,
- }) => Self {
- expr_type: Some(ExprType::AggregateUdfExpr(Box::new(
- protobuf::AggregateUdfExprNode {
- fun_name: fun.name().to_string(),
- args: args.iter().map(|expr| expr.try_into()).collect::<Result<
- Vec<_>,
- Error,
- >>(
- )?,
- filter: match filter {
- Some(e) => Some(Box::new(e.as_ref().try_into()?)),
- None => None,
- },
- order_by: match order_by {
- Some(e) => e
- .iter()
- .map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- None => vec![],
- },
- },
- ))),
- },
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
expr: Some(Box::new(expr.as_ref().try_into()?)),
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3ab001298e..45727c39a3 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1375,9 +1375,10 @@ fn roundtrip_aggregate_udf() {
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
- let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new(
+ let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
Arc::new(dummy_agg.clone()),
vec![lit(1.0_f64)],
+ false,
Some(Box::new(lit(true))),
None,
));
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index 24ba4d1b50..958e038798 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -135,8 +135,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function
if let Some(fm) = self.context_provider.get_aggregate_meta(&name) {
let args = self.function_args_to_expr(args, schema, planner_context)?;
- return Ok(Expr::AggregateUDF(expr::AggregateUDF::new(
- fm, args, None, None,
+ return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
+ fm, args, false, None, None,
)));
}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 25fe6b6633..b8c130055a 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -34,6 +34,7 @@ use datafusion_common::{
internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result,
ScalarValue,
};
+use datafusion_expr::expr::AggregateFunctionDefinition;
use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
@@ -706,7 +707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<Expr> {
match self.sql_expr_to_logical_expr(expr, schema, planner_context)? {
Expr::AggregateFunction(expr::AggregateFunction {
- fun,
+ func_def: AggregateFunctionDefinition::BuiltIn(fun),
args,
distinct,
order_by,
diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index 356c536051..c546ca7552 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -170,11 +170,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
select_exprs
.iter()
.filter(|select_expr| match select_expr {
- Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false,
- Expr::Alias(Alias { expr, name: _, .. }) => !matches!(
- **expr,
- Expr::AggregateFunction(_) | Expr::AggregateUDF(_)
- ),
+ Expr::AggregateFunction(_) => false,
+ Expr::Alias(Alias { expr, name: _, .. }) => {
+ !matches!(**expr, Expr::AggregateFunction(_))
+ }
_ => true,
})
.cloned()
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index b7a51032dc..cf05d814a5 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -692,21 +692,14 @@ pub async fn from_substrait_agg_func(
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
- Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
- fun,
- args,
- filter,
- order_by,
- })))
+ Ok(Arc::new(Expr::AggregateFunction(
+ expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by),
+ )))
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
{
- Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
- fun,
- args,
- distinct,
- filter,
- order_by,
- })))
+ Ok(Arc::new(Expr::AggregateFunction(
+ expr::AggregateFunction::new(fun, args, distinct, filter, order_by),
+ )))
} else {
not_impl_err!(
"Aggregated function {} is not supported: function anchor = {:?}",
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 2be3e7b4e8..d576e70711 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err};
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
- Alias, BinaryExpr, Case, Cast, GroupingSet, InList, ScalarFunctionDefinition, Sort,
- WindowFunction,
+ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
+ ScalarFunctionDefinition, Sort, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
@@ -578,65 +578,73 @@ pub fn to_substrait_agg_measure(
),
) -> Result<Measure> {
match expr {
- Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => {
- let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
- } else {
- vec![]
- };
- let mut arguments: Vec<FunctionArgument> = vec![];
- for arg in args {
- arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
- }
- let function_anchor = _register_function(fun.to_string(), extension_info);
- Ok(Measure {
- measure: Some(AggregateFunction {
- function_reference: function_anchor,
- arguments,
- sorts,
- output_type: None,
- invocation: match distinct {
- true => AggregationInvocation::Distinct as i32,
- false => AggregationInvocation::All as i32,
- },
- phase: AggregationPhase::Unspecified as i32,
- args: vec![],
- options: vec![],
- }),
- filter: match filter {
- Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
- None => None
+ Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => {
+ match func_def {
+ AggregateFunctionDefinition::BuiltIn (fun) => {
+ let sorts = if let Some(order_by) = order_by {
+ order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
+ } else {
+ vec![]
+ };
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+ }
+ let function_anchor = _register_function(fun.to_string(), extension_info);
+ Ok(Measure {
+ measure: Some(AggregateFunction {
+ function_reference: function_anchor,
+ arguments,
+ sorts,
+ output_type: None,
+ invocation: match distinct {
+ true => AggregationInvocation::Distinct as i32,
+ false => AggregationInvocation::All as i32,
+ },
+ phase: AggregationPhase::Unspecified as i32,
+ args: vec![],
+ options: vec![],
+ }),
+ filter: match filter {
+ Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
+ None => None
+ }
+ })
}
- })
- }
- Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
- let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
- } else {
- vec![]
- };
- let mut arguments: Vec<FunctionArgument> = vec![];
- for arg in args {
- arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
- }
- let function_anchor = _register_function(fun.name().to_string(), extension_info);
- Ok(Measure {
- measure: Some(AggregateFunction {
- function_reference: function_anchor,
- arguments,
- sorts,
- output_type: None,
- invocation: AggregationInvocation::All as i32,
- phase: AggregationPhase::Unspecified as i32,
- args: vec![],
- options: vec![],
- }),
- filter: match filter {
- Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
- None => None
+ AggregateFunctionDefinition::UDF(fun) => {
+ let sorts = if let Some(order_by) = order_by {
+ order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
+ } else {
+ vec![]
+ };
+ let mut arguments: Vec<FunctionArgument> = vec![];
+ for arg in args {
+ arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+ }
+ let function_anchor = _register_function(fun.name().to_string(), extension_info);
+ Ok(Measure {
+ measure: Some(AggregateFunction {
+ function_reference: function_anchor,
+ arguments,
+ sorts,
+ output_type: None,
+ invocation: AggregationInvocation::All as i32,
+ phase: AggregationPhase::Unspecified as i32,
+ args: vec![],
+ options: vec![],
+ }),
+ filter: match filter {
+ Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
+ None => None
+ }
+ })
}
- })
- },
+ AggregateFunctionDefinition::Name(name) => {
+ internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name)
+ }
+ }
+
+ }
Expr::Alias(Alias{expr,..})=> {
to_substrait_agg_measure(expr, schema, extension_info)
}