You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this copy.)
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/20 10:48:29 UTC
[arrow-datafusion] branch master updated: Refactor Expr::AggregateFunction and Expr::WindowFunction to use struct (#4671)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new fe477e488 Refactor Expr::AggregateFunction and Expr::WindowFunction to use struct (#4671)
fe477e488 is described below
commit fe477e48832adb59e8992d92f9a292128dfb0a97
Author: Jeffrey <22...@users.noreply.github.com>
AuthorDate: Tue Dec 20 21:48:23 2022 +1100
Refactor Expr::AggregateFunction and Expr::WindowFunction to use struct (#4671)
* Refactor Expr::WindowFunction to struct
* Refactor Expr::AggregateFunction to struct
* Fix
---
datafusion/core/src/dataframe.rs | 16 +--
datafusion/core/src/physical_plan/planner.rs | 25 ++--
datafusion/expr/src/expr.rs | 103 ++++++++++-----
datafusion/expr/src/expr_fn.rs | 122 ++++++++---------
datafusion/expr/src/expr_rewriter.rs | 23 ++--
datafusion/expr/src/expr_schema.rs | 8 +-
datafusion/expr/src/expr_visitor.rs | 8 +-
datafusion/expr/src/utils.rs | 144 ++++++++++-----------
datafusion/optimizer/src/push_down_projection.rs | 13 +-
.../optimizer/src/single_distinct_to_groupby.rs | 40 +++---
datafusion/optimizer/src/type_coercion.rs | 55 ++++----
datafusion/proto/src/logical_plan/from_proto.rs | 27 ++--
datafusion/proto/src/logical_plan/mod.rs | 94 +++++++-------
datafusion/proto/src/logical_plan/to_proto.rs | 10 +-
datafusion/sql/src/planner.rs | 35 +++--
datafusion/sql/src/utils.rs | 41 +++---
16 files changed, 407 insertions(+), 357 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 77f7e7615..df56935e4 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -807,7 +807,7 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_expr::{
- avg, cast, count, count_distinct, create_udf, lit, max, min, sum,
+ avg, cast, count, count_distinct, create_udf, expr, lit, max, min, sum,
BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame,
WindowFunction,
};
@@ -861,13 +861,13 @@ mod tests {
async fn select_with_window_exprs() -> Result<()> {
// build plan using Table API
let t = test_table().await?;
- let first_row = Expr::WindowFunction {
- fun: WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
- args: vec![col("aggregate_test_100.c1")],
- partition_by: vec![col("aggregate_test_100.c2")],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
+ let first_row = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue),
+ vec![col("aggregate_test_100.c1")],
+ vec![col("aggregate_test_100.c2")],
+ vec![],
+ WindowFrame::new(false),
+ ));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 548a1dd36..5d6f8c99c 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -62,7 +62,8 @@ use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, TryCast,
+ self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet,
+ Like, TryCast, WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan;
@@ -190,15 +191,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_physical_name(&fun.name, false, args)
}
- Expr::WindowFunction { fun, args, .. } => {
+ Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
}
- Expr::AggregateFunction {
+ Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
- } => create_function_physical_name(&fun.to_string(), *distinct, args),
+ }) => create_function_physical_name(&fun.to_string(), *distinct, args),
Expr::AggregateUDF { fun, args, filter } => {
if filter.is_some() {
return Err(DataFusionError::Execution(
@@ -547,18 +548,18 @@ impl DefaultPhysicalPlanner {
};
let get_sort_keys = |expr: &Expr| match expr {
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
..
- } => generate_sort_key(partition_by, order_by),
+ }) => generate_sort_key(partition_by, order_by),
Expr::Alias(expr, _) => {
// Convert &Box<T> to &T
match &**expr {
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction{
ref partition_by,
ref order_by,
- ..} => generate_sort_key(partition_by, order_by),
+ ..}) => generate_sort_key(partition_by, order_by),
_ => unreachable!(),
}
}
@@ -1502,13 +1503,13 @@ pub fn create_window_expr_with_name(
) -> Result<Arc<dyn WindowExpr>> {
let name = name.into();
match e {
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
- } => {
+ }) => {
let args = args
.iter()
.map(|e| {
@@ -1608,12 +1609,12 @@ pub fn create_aggregate_expr_with_name(
execution_props: &ExecutionProps,
) -> Result<Arc<dyn AggregateExpr>> {
match e {
- Expr::AggregateFunction {
+ Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
..
- } => {
+ }) => {
let args = args
.iter()
.map(|e| {
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 42e6e105f..5c8700472 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -166,29 +166,9 @@ pub enum Expr {
args: Vec<Expr>,
},
/// Represents the call of an aggregate built-in function with arguments.
- AggregateFunction {
- /// Name of the function
- fun: aggregate_function::AggregateFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- /// Whether this is a DISTINCT aggregation or not
- distinct: bool,
- /// Optional filter
- filter: Option<Box<Expr>>,
- },
+ AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
- WindowFunction {
- /// Name of the function
- fun: window_function::WindowFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- /// List of partition by expressions
- partition_by: Vec<Expr>,
- /// List of order by expressions
- order_by: Vec<Expr>,
- /// Window frame
- window_frame: window_frame::WindowFrame,
- },
+ WindowFunction(WindowFunction),
/// aggregate function
AggregateUDF {
/// The function
@@ -472,6 +452,69 @@ impl Sort {
}
}
+/// Aggregate function
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct AggregateFunction {
+ /// Name of the function
+ pub fun: aggregate_function::AggregateFunction,
+ /// List of expressions to feed to the functions as arguments
+ pub args: Vec<Expr>,
+ /// Whether this is a DISTINCT aggregation or not
+ pub distinct: bool,
+ /// Optional filter
+ pub filter: Option<Box<Expr>>,
+}
+
+impl AggregateFunction {
+ pub fn new(
+ fun: aggregate_function::AggregateFunction,
+ args: Vec<Expr>,
+ distinct: bool,
+ filter: Option<Box<Expr>>,
+ ) -> Self {
+ Self {
+ fun,
+ args,
+ distinct,
+ filter,
+ }
+ }
+}
+
+/// Window function
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct WindowFunction {
+ /// Name of the function
+ pub fun: window_function::WindowFunction,
+ /// List of expressions to feed to the functions as arguments
+ pub args: Vec<Expr>,
+ /// List of partition by expressions
+ pub partition_by: Vec<Expr>,
+ /// List of order by expressions
+ pub order_by: Vec<Expr>,
+ /// Window frame
+ pub window_frame: window_frame::WindowFrame,
+}
+
+impl WindowFunction {
+ /// Create a new Window expression
+ pub fn new(
+ fun: window_function::WindowFunction,
+ args: Vec<Expr>,
+ partition_by: Vec<Expr>,
+ order_by: Vec<Expr>,
+ window_frame: window_frame::WindowFrame,
+ ) -> Self {
+ Self {
+ fun,
+ args,
+ partition_by,
+ order_by,
+ window_frame,
+ }
+ }
+}
+
/// Grouping sets
/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS
/// for Postgres definition.
@@ -867,13 +910,13 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args, false)
}
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
- } => {
+ }) => {
fmt_function(f, &fun.to_string(), false, args, false)?;
if !partition_by.is_empty() {
write!(f, " PARTITION BY {:?}", partition_by)?;
@@ -888,13 +931,13 @@ impl fmt::Debug for Expr {
)?;
Ok(())
}
- Expr::AggregateFunction {
+ Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
ref args,
filter,
..
- } => {
+ }) => {
fmt_function(f, &fun.to_string(), *distinct, args, true)?;
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {})", fe)?;
@@ -1223,13 +1266,13 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&fun.to_string(), false, args)
}
Expr::ScalarUDF { fun, args, .. } => create_function_name(&fun.name, false, args),
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
fun,
args,
window_frame,
partition_by,
order_by,
- } => {
+ }) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];
if !partition_by.is_empty() {
@@ -1241,12 +1284,12 @@ fn create_name(e: &Expr) -> Result<String> {
parts.push(format!("{}", window_frame));
Ok(parts.join(" "))
}
- Expr::AggregateFunction {
+ Expr::AggregateFunction(AggregateFunction {
fun,
distinct,
args,
filter,
- } => {
+ }) => {
let name = create_function_name(&fun.to_string(), *distinct, args)?;
if let Some(fe) = filter {
Ok(format!("{} FILTER (WHERE {})", name, fe))
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index d305013f1..fbb406ad2 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -17,7 +17,7 @@
//! Functions for creating logical expressions
-use crate::expr::{BinaryExpr, Cast, GroupingSet, TryCast};
+use crate::expr::{AggregateFunction, BinaryExpr, Cast, GroupingSet, TryCast};
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
@@ -64,62 +64,62 @@ pub fn or(left: Expr, right: Expr) -> Expr {
/// Create an expression to represent the min() aggregate function
pub fn min(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Min,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Min,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Create an expression to represent the max() aggregate function
pub fn max(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Max,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Max,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Create an expression to represent the sum() aggregate function
pub fn sum(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Sum,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Sum,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Avg,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Avg,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Create an expression to represent the count() aggregate function
pub fn count(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Count,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Count,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Create an expression to represent the count(distinct) aggregate function
pub fn count_distinct(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::Count,
- distinct: true,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::Count,
+ vec![expr],
+ true,
+ None,
+ ))
}
/// Create an in_list expression
@@ -167,32 +167,32 @@ pub fn random() -> Expr {
/// error distribution over all possible sets.
/// It does not guarantee an upper bound on the error for any specific input set.
pub fn approx_distinct(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::ApproxDistinct,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::ApproxDistinct,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Calculate an approximation of the median for `expr`.
pub fn approx_median(expr: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::ApproxMedian,
- distinct: false,
- args: vec![expr],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::ApproxMedian,
+ vec![expr],
+ false,
+ None,
+ ))
}
/// Calculate an approximation of the specified `percentile` for `expr`.
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::ApproxPercentileCont,
- distinct: false,
- args: vec![expr, percentile],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::ApproxPercentileCont,
+ vec![expr, percentile],
+ false,
+ None,
+ ))
}
/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`.
@@ -201,12 +201,12 @@ pub fn approx_percentile_cont_with_weight(
weight_expr: Expr,
percentile: Expr,
) -> Expr {
- Expr::AggregateFunction {
- fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
- distinct: false,
- args: vec![expr, weight_expr, percentile],
- filter: None,
- }
+ Expr::AggregateFunction(AggregateFunction::new(
+ aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
+ vec![expr, weight_expr, percentile],
+ false,
+ None,
+ ))
}
/// Create an EXISTS subquery expression
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index 8f90d4ca8..6ad7b5eec 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -18,7 +18,8 @@
//! Expression rewriter
use crate::expr::{
- Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, Like, Sort, TryCast,
+ AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet,
+ Like, Sort, TryCast, WindowFunction,
};
use crate::logical_plan::{Aggregate, Projection};
use crate::utils::{from_plan, grouping_set_to_exprlist};
@@ -224,30 +225,30 @@ impl ExprRewritable for Expr {
args: rewrite_vec(args, rewriter)?,
fun,
},
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
args,
fun,
partition_by,
order_by,
window_frame,
- } => Expr::WindowFunction {
- args: rewrite_vec(args, rewriter)?,
+ }) => Expr::WindowFunction(WindowFunction::new(
fun,
- partition_by: rewrite_vec(partition_by, rewriter)?,
- order_by: rewrite_vec(order_by, rewriter)?,
+ rewrite_vec(args, rewriter)?,
+ rewrite_vec(partition_by, rewriter)?,
+ rewrite_vec(order_by, rewriter)?,
window_frame,
- },
- Expr::AggregateFunction {
+ )),
+ Expr::AggregateFunction(AggregateFunction {
args,
fun,
distinct,
filter,
- } => Expr::AggregateFunction {
- args: rewrite_vec(args, rewriter)?,
+ }) => Expr::AggregateFunction(AggregateFunction::new(
fun,
+ rewrite_vec(args, rewriter)?,
distinct,
filter,
- },
+ )),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 2fcd673d0..c1a625cf4 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -16,7 +16,9 @@
// under the License.
use super::{Between, Expr, Like};
-use crate::expr::{BinaryExpr, Cast, GetIndexedField, Sort, TryCast};
+use crate::expr::{
+ AggregateFunction, BinaryExpr, Cast, GetIndexedField, Sort, TryCast, WindowFunction,
+};
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::binary_operator_data_type;
use crate::{aggregate_function, function, window_function};
@@ -77,14 +79,14 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
function::return_type(fun, &data_types)
}
- Expr::WindowFunction { fun, args, .. } => {
+ Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
window_function::return_type(fun, &data_types)
}
- Expr::AggregateFunction { fun, args, .. } => {
+ Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs
index 2e574e587..ed80547f1 100644
--- a/datafusion/expr/src/expr_visitor.rs
+++ b/datafusion/expr/src/expr_visitor.rs
@@ -17,7 +17,7 @@
//! Expression visitor
-use crate::expr::{Cast, Sort};
+use crate::expr::{AggregateFunction, Cast, Sort, WindowFunction};
use crate::{
expr::{BinaryExpr, GroupingSet, TryCast},
Between, Expr, GetIndexedField, Like,
@@ -180,7 +180,7 @@ impl ExprVisitable for Expr {
Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
- Expr::AggregateFunction { args, filter, .. }
+ Expr::AggregateFunction(AggregateFunction { args, filter, .. })
| Expr::AggregateUDF { args, filter, .. } => {
if let Some(f) = filter {
let mut aggr_exprs = args.clone();
@@ -193,12 +193,12 @@ impl ExprVisitable for Expr {
.try_fold(visitor, |visitor, arg| arg.accept(visitor))
}
}
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
args,
partition_by,
order_by,
..
- } => {
+ }) => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2577c3a19..89229a3d4 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -17,7 +17,7 @@
//! Expression utilities
-use crate::expr::Sort;
+use crate::expr::{Sort, WindowFunction};
use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use crate::logical_plan::builder::build_join_schema;
@@ -315,7 +315,7 @@ pub fn group_window_expr_by_sort_keys(
) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
let mut result = vec![];
window_expr.iter().try_for_each(|expr| match expr {
- Expr::WindowFunction { partition_by, order_by, .. } => {
+ Expr::WindowFunction(WindowFunction{ partition_by, order_by, .. }) => {
let sort_key = generate_sort_key(partition_by, order_by)?;
if let Some((_, values)) = result.iter_mut().find(
|group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key),
@@ -952,34 +952,34 @@ mod tests {
#[test]
fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
- let max1 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
- let max2 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
- let min3 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
- let sum4 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
- args: vec![col("age")],
- partition_by: vec![],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
+ let max1 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("name")],
+ vec![],
+ vec![],
+ WindowFrame::new(false),
+ ));
+ let max2 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("name")],
+ vec![],
+ vec![],
+ WindowFrame::new(false),
+ ));
+ let min3 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Min),
+ vec![col("name")],
+ vec![],
+ vec![],
+ WindowFrame::new(false),
+ ));
+ let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ vec![col("age")],
+ vec![],
+ vec![],
+ WindowFrame::new(false),
+ ));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
let key = vec![];
@@ -995,34 +995,34 @@ mod tests {
let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true));
let created_at_desc =
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true));
- let max1 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![age_asc.clone(), name_desc.clone()],
- window_frame: WindowFrame::new(true),
- };
- let max2 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![],
- window_frame: WindowFrame::new(false),
- };
- let min3 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![age_asc.clone(), name_desc.clone()],
- window_frame: WindowFrame::new(true),
- };
- let sum4 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
- args: vec![col("age")],
- partition_by: vec![],
- order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
- window_frame: WindowFrame::new(true),
- };
+ let max1 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("name")],
+ vec![],
+ vec![age_asc.clone(), name_desc.clone()],
+ WindowFrame::new(true),
+ ));
+ let max2 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("name")],
+ vec![],
+ vec![],
+ WindowFrame::new(false),
+ ));
+ let min3 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Min),
+ vec![col("name")],
+ vec![],
+ vec![age_asc.clone(), name_desc.clone()],
+ WindowFrame::new(true),
+ ));
+ let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ vec![col("age")],
+ vec![],
+ vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
+ WindowFrame::new(true),
+ ));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
@@ -1043,27 +1043,27 @@ mod tests {
#[test]
fn test_find_sort_exprs() -> Result<()> {
let exprs = &[
- Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("name")],
- partition_by: vec![],
- order_by: vec![
+ Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("name")],
+ vec![],
+ vec![
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
- window_frame: WindowFrame::new(true),
- },
- Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
- args: vec![col("age")],
- partition_by: vec![],
- order_by: vec![
+ WindowFrame::new(true),
+ )),
+ Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ vec![col("age")],
+ vec![],
+ vec![
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
- window_frame: WindowFrame::new(true),
- },
+ WindowFrame::new(true),
+ )),
];
let expected = vec![
Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)),
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index fb8e241fd..11861ae3e 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -524,6 +524,7 @@ mod tests {
use crate::test::*;
use crate::OptimizerContext;
use arrow::datatypes::{DataType, Schema};
+ use datafusion_expr::expr;
use datafusion_expr::expr::Cast;
use datafusion_expr::{
col, count, lit,
@@ -985,12 +986,12 @@ mod tests {
fn aggregate_filter_pushdown() -> Result<()> {
let table_scan = test_table_scan()?;
- let aggr_with_filter = Expr::AggregateFunction {
- fun: AggregateFunction::Count,
- args: vec![col("b")],
- distinct: false,
- filter: Some(Box::new(col("c").gt(lit(42)))),
- };
+ let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("b")],
+ false,
+ Some(Box::new(col("c").gt(lit(42)))),
+ ));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 4ba2e1599..bf4231c1f 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -21,6 +21,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
col,
+ expr::AggregateFunction,
logical_plan::{Aggregate, LogicalPlan, Projection},
utils::columnize_expr,
Expr, ExprSchemable,
@@ -61,7 +62,10 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
let mut fields_set = HashSet::new();
let mut distinct_count = 0;
for expr in aggr_expr {
- if let Expr::AggregateFunction { distinct, args, .. } = expr {
+ if let Expr::AggregateFunction(AggregateFunction {
+ distinct, args, ..
+ }) = expr
+ {
if *distinct {
distinct_count += 1;
}
@@ -121,21 +125,24 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let new_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
- Expr::AggregateFunction {
- fun, args, filter, ..
- } => {
+ Expr::AggregateFunction(AggregateFunction {
+ fun,
+ args,
+ filter,
+ ..
+ }) => {
// is_single_distinct_agg ensure args.len=1
if group_fields_set.insert(args[0].display_name()?) {
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
}
- Ok(Expr::AggregateFunction {
- fun: fun.clone(),
- args: vec![col(SINGLE_DISTINCT_ALIAS)],
- distinct: false, // intentional to remove distinct here
- filter: filter.clone(),
- })
+ Ok(Expr::AggregateFunction(AggregateFunction::new(
+ fun.clone(),
+ vec![col(SINGLE_DISTINCT_ALIAS)],
+ false, // intentional to remove distinct here
+ filter.clone(),
+ )))
}
_ => Ok(aggr_expr.clone()),
})
@@ -216,6 +223,7 @@ mod tests {
use super::*;
use crate::test::*;
use crate::OptimizerContext;
+ use datafusion_expr::expr;
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
@@ -392,12 +400,12 @@ mod tests {
vec![col("a")],
vec![
count_distinct(col("b")),
- Expr::AggregateFunction {
- fun: AggregateFunction::Max,
- distinct: true,
- args: vec![col("b")],
- filter: None,
- },
+ Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Max,
+ vec![col("b")],
+ true,
+ None,
+ )),
],
)?
.build()?;
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index d1655fe35..a957bb41b 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -24,7 +24,7 @@ use arrow::datatypes::{DataType, IntervalUnit};
use datafusion_common::{
parse_interval, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
-use datafusion_expr::expr::{Between, BinaryExpr, Case, Like};
+use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like, WindowFunction};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
@@ -376,24 +376,21 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
- Expr::AggregateFunction {
+ Expr::AggregateFunction(expr::AggregateFunction {
fun,
args,
distinct,
filter,
- } => {
+ }) => {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
&args,
&self.schema,
&aggregate_function::signature(&fun),
)?;
- let expr = Expr::AggregateFunction {
- fun,
- args: new_expr,
- distinct,
- filter,
- };
+ let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ fun, new_expr, distinct, filter,
+ ));
Ok(expr)
}
Expr::AggregateUDF { fun, args, filter } => {
@@ -409,22 +406,22 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
- Expr::WindowFunction {
+ Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
- } => {
+ }) => {
let window_frame =
get_coerced_window_frame(window_frame, &self.schema, &order_by)?;
- let expr = Expr::WindowFunction {
+ let expr = Expr::WindowFunction(WindowFunction::new(
fun,
args,
partition_by,
order_by,
window_frame,
- };
+ ));
Ok(expr)
}
expr => Ok(expr),
@@ -592,7 +589,7 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
- use datafusion_expr::expr::Like;
+ use datafusion_expr::expr::{self, Like};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{
cast, col, concat, concat_ws, create_udaf, is_true,
@@ -776,24 +773,24 @@ mod test {
fn agg_function_case() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
- let agg_expr = Expr::AggregateFunction {
+ let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- args: vec![lit(12i64)],
- distinct: false,
- filter: None,
- };
+ vec![lit(12i64)],
+ false,
+ None,
+ ));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(Int64(12))\n EmptyRelation";
assert_optimized_plan_eq(&plan, expected)?;
let empty = empty_with_type(DataType::Int32);
let fun: AggregateFunction = AggregateFunction::Avg;
- let agg_expr = Expr::AggregateFunction {
+ let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- args: vec![col("a")],
- distinct: false,
- filter: None,
- };
+ vec![col("a")],
+ false,
+ None,
+ ));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(a)\n EmptyRelation";
assert_optimized_plan_eq(&plan, expected)?;
@@ -804,12 +801,12 @@ mod test {
fn agg_function_invalid_input() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
- let agg_expr = Expr::AggregateFunction {
+ let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- args: vec![lit("1")],
- distinct: false,
- filter: None,
- };
+ vec![lit("1")],
+ false,
+ None,
+ ));
let err = Projection::try_new(vec![agg_expr], empty).err().unwrap();
assert_eq!(
"Plan(\"The function Avg does not support inputs of type Utf8.\")",
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 96981aff3..9ba617116 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -36,7 +36,7 @@ use datafusion_expr::{
abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil,
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin,
date_part, date_trunc, digest, exp,
- expr::Sort,
+ expr::{self, Sort, WindowFunction},
floor, from_unixtime, left, ln, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, now, nullif, octet_length, power, random, regexp_match,
@@ -924,15 +924,15 @@ pub fn parse_expr(
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = parse_i32_to_aggregate_function(i)?;
- Ok(Expr::WindowFunction {
- fun: datafusion_expr::window_function::WindowFunction::AggregateFunction(
+ Ok(Expr::WindowFunction(WindowFunction::new(
+ datafusion_expr::window_function::WindowFunction::AggregateFunction(
aggr_function,
),
- args: vec![parse_required_expr(&expr.expr, registry, "expr")?],
+ vec![parse_required_expr(&expr.expr, registry, "expr")?],
partition_by,
order_by,
window_frame,
- })
+ )))
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
let built_in_function = protobuf::BuiltInWindowFunction::from_i32(*i)
@@ -943,31 +943,30 @@ pub fn parse_expr(
.map(|e| vec![e])
.unwrap_or_else(Vec::new);
- Ok(Expr::WindowFunction {
- fun: datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction(
+ Ok(Expr::WindowFunction(WindowFunction::new(
+ datafusion_expr::window_function::WindowFunction::BuiltInWindowFunction(
built_in_function,
),
args,
partition_by,
order_by,
window_frame,
- })
+ )))
}
}
}
ExprType::AggregateExpr(expr) => {
let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?;
- Ok(Expr::AggregateFunction {
+ Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- args: expr
- .expr
+ expr.expr
.iter()
.map(|e| parse_expr(e, registry))
.collect::<Result<Vec<_>, _>>()?,
- distinct: expr.distinct,
- filter: parse_optional_expr(&expr.filter, registry)?.map(Box::new),
- })
+ expr.distinct,
+ parse_optional_expr(&expr.filter, registry)?.map(Box::new),
+ )))
}
ExprType::Alias(alias) => Ok(Expr::Alias(
Box::new(parse_required_expr(&alias.expr, registry, "expr")?),
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index d099f4022..ad92c68a1 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -1387,7 +1387,7 @@ mod roundtrip_tests {
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
use datafusion_expr::expr::{
- Between, BinaryExpr, Case, Cast, GroupingSet, Like, Sort,
+ self, Between, BinaryExpr, Case, Cast, GroupingSet, Like, Sort,
};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode};
use datafusion_expr::{
@@ -2484,36 +2484,36 @@ mod roundtrip_tests {
#[test]
fn roundtrip_count() {
- let test_expr = Expr::AggregateFunction {
- fun: AggregateFunction::Count,
- args: vec![col("bananas")],
- distinct: false,
- filter: None,
- };
+ let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("bananas")],
+ false,
+ None,
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
}
#[test]
fn roundtrip_count_distinct() {
- let test_expr = Expr::AggregateFunction {
- fun: AggregateFunction::Count,
- args: vec![col("bananas")],
- distinct: true,
- filter: None,
- };
+ let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::Count,
+ vec![col("bananas")],
+ true,
+ None,
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
}
#[test]
fn roundtrip_approx_percentile_cont() {
- let test_expr = Expr::AggregateFunction {
- fun: AggregateFunction::ApproxPercentileCont,
- args: vec![col("bananas"), lit(0.42_f32)],
- distinct: false,
- filter: None,
- };
+ let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
+ AggregateFunction::ApproxPercentileCont,
+ vec![col("bananas"), lit(0.42_f32)],
+ false,
+ None,
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
@@ -2654,26 +2654,26 @@ mod roundtrip_tests {
let ctx = SessionContext::new();
// 1. without window_frame
- let test_expr1 = Expr::WindowFunction {
- fun: WindowFunction::BuiltInWindowFunction(
+ let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::BuiltInWindowFunction(
datafusion_expr::window_function::BuiltInWindowFunction::Rank,
),
- args: vec![],
- partition_by: vec![col("col1")],
- order_by: vec![col("col2")],
- window_frame: WindowFrame::new(true),
- };
+ vec![],
+ vec![col("col1")],
+ vec![col("col2")],
+ WindowFrame::new(true),
+ ));
// 2. with default window_frame
- let test_expr2 = Expr::WindowFunction {
- fun: WindowFunction::BuiltInWindowFunction(
+ let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::BuiltInWindowFunction(
datafusion_expr::window_function::BuiltInWindowFunction::Rank,
),
- args: vec![],
- partition_by: vec![col("col1")],
- order_by: vec![col("col2")],
- window_frame: WindowFrame::new(true),
- };
+ vec![],
+ vec![col("col1")],
+ vec![col("col2")],
+ WindowFrame::new(true),
+ ));
// 3. with window_frame with row numbers
let range_number_frame = WindowFrame {
@@ -2682,15 +2682,15 @@ mod roundtrip_tests {
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
};
- let test_expr3 = Expr::WindowFunction {
- fun: WindowFunction::BuiltInWindowFunction(
+ let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::BuiltInWindowFunction(
datafusion_expr::window_function::BuiltInWindowFunction::Rank,
),
- args: vec![],
- partition_by: vec![col("col1")],
- order_by: vec![col("col2")],
- window_frame: range_number_frame,
- };
+ vec![],
+ vec![col("col1")],
+ vec![col("col2")],
+ range_number_frame,
+ ));
// 4. test with AggregateFunction
let row_number_frame = WindowFrame {
@@ -2699,13 +2699,13 @@ mod roundtrip_tests {
end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))),
};
- let test_expr4 = Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
- args: vec![col("col1")],
- partition_by: vec![col("col1")],
- order_by: vec![col("col2")],
- window_frame: row_number_frame,
- };
+ let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(AggregateFunction::Max),
+ vec![col("col1")],
+ vec![col("col1")],
+ vec![col("col2")],
+ row_number_frame,
+ ));
roundtrip_expr_test(test_expr1, ctx.clone());
roundtrip_expr_test(test_expr2, ctx.clone());
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 1e9600564..413e989fa 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -35,7 +35,7 @@ use arrow::datatypes::{
};
use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference, ScalarValue};
use datafusion_expr::expr::{
- Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, Sort,
+ self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, Sort,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -533,13 +533,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
expr_type: Some(ExprType::SimilarTo(pb)),
}
}
- Expr::WindowFunction {
+ Expr::WindowFunction(expr::WindowFunction {
ref fun,
ref args,
ref partition_by,
ref order_by,
ref window_frame,
- } => {
+ }) => {
let window_function = match fun {
WindowFunction::AggregateFunction(fun) => {
protobuf::window_expr_node::WindowFunction::AggrFunction(
@@ -581,12 +581,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
expr_type: Some(ExprType::WindowExpr(window_expr)),
}
}
- Expr::AggregateFunction {
+ Expr::AggregateFunction(expr::AggregateFunction {
ref fun,
ref args,
ref distinct,
ref filter
- } => {
+ }) => {
let aggr_function = match fun {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 60ee7f5ec..6761eab48 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -43,7 +43,7 @@ use datafusion_common::{
};
use datafusion_common::{OwnedTableReference, TableReference};
use datafusion_expr::expr::{
- Between, BinaryExpr, Case, Cast, GroupingSet, Like, Sort, TryCast,
+ self, Between, BinaryExpr, Case, Cast, GroupingSet, Like, Sort, TryCast,
};
use datafusion_expr::expr_rewriter::normalize_col;
use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
@@ -2260,9 +2260,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::AggregateExpressionWithFilter { expr, filter } => {
match self.sql_expr_to_logical_expr(*expr, schema, planner_context)? {
- Expr::AggregateFunction {
+ Expr::AggregateFunction(expr::AggregateFunction {
fun, args, distinct, ..
- } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, planner_context)?)) }),
+ }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, args, distinct, Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, planner_context)?)) ))),
_ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string()))
}
}
@@ -2334,24 +2334,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema,
)?;
- Expr::WindowFunction {
- fun: WindowFunction::AggregateFunction(
+ Expr::WindowFunction(expr::WindowFunction::new(
+ WindowFunction::AggregateFunction(
aggregate_fun,
),
args,
partition_by,
order_by,
window_frame,
- }
+ ))
}
_ => {
- Expr::WindowFunction {
+ Expr::WindowFunction(expr::WindowFunction::new(
fun,
- args: self.function_args_to_expr(function.args, schema)?,
+ self.function_args_to_expr(function.args, schema)?,
partition_by,
order_by,
window_frame,
- }
+ ))
}
};
return Ok(expr);
@@ -2361,12 +2361,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Ok(fun) = AggregateFunction::from_str(&name) {
let distinct = function.distinct;
let (fun, args) = self.aggregate_fn_to_expr(fun, function.args, schema)?;
- return Ok(Expr::AggregateFunction {
+ return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
- distinct,
args,
- filter: None,
- });
+ distinct,
+ None,
+ )));
};
// finally, user-defined functions (UDF) and UDAF
@@ -2528,12 +2528,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// next, aggregate built-ins
let fun = AggregateFunction::ArrayAgg;
- Ok(Expr::AggregateFunction {
- fun,
- distinct,
- args,
- filter: None,
- })
+ Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
+ fun, args, distinct, None,
+ )))
}
fn function_args_to_expr(
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 1934b8f0d..a976712ac 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -22,7 +22,8 @@ use sqlparser::ast::Ident;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
- Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like,
+ AggregateFunction, Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like,
+ WindowFunction,
};
use datafusion_expr::expr::{Cast, Sort};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
@@ -162,42 +163,40 @@ where
// No replacement was provided, clone the node and recursively call
// clone_with_replacement() on any nested expressions.
None => match expr {
- Expr::AggregateFunction {
+ Expr::AggregateFunction(AggregateFunction {
fun,
args,
distinct,
filter,
- } => Ok(Expr::AggregateFunction {
- fun: fun.clone(),
- args: args
- .iter()
+ }) => Ok(Expr::AggregateFunction(AggregateFunction::new(
+ fun.clone(),
+ args.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<Expr>>>()?,
- distinct: *distinct,
- filter: filter.clone(),
- }),
- Expr::WindowFunction {
+ *distinct,
+ filter.clone(),
+ ))),
+ Expr::WindowFunction(WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
- } => Ok(Expr::WindowFunction {
- fun: fun.clone(),
- args: args
- .iter()
+ }) => Ok(Expr::WindowFunction(WindowFunction::new(
+ fun.clone(),
+ args.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<_>>>()?,
- partition_by: partition_by
+ partition_by
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<_>>>()?,
- order_by: order_by
+ order_by
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<_>>>()?,
- window_frame: window_frame.clone(),
- }),
+ window_frame.clone(),
+ ))),
Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
args: args
@@ -481,11 +480,13 @@ pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr
let all_partition_keys = window_exprs
.iter()
.map(|expr| match expr {
- Expr::WindowFunction { partition_by, .. } => Ok(partition_by),
+ Expr::WindowFunction(WindowFunction { partition_by, .. }) => Ok(partition_by),
Expr::Alias(expr, _) => {
// convert &Box<T> to &T
match &**expr {
- Expr::WindowFunction { partition_by, .. } => Ok(partition_by),
+ Expr::WindowFunction(WindowFunction { partition_by, .. }) => {
+ Ok(partition_by)
+ }
expr => Err(DataFusionError::Execution(format!(
"Impossibly got non-window expr {:?}",
expr