You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/03 21:19:04 UTC
[arrow-datafusion] branch master updated: fix window expression
with alias (#463)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new e82d053 fix window expression with alias (#463)
e82d053 is described below
commit e82d053b526d669e9c845e3fda70147aaf7d3488
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Fri Jun 4 05:18:53 2021 +0800
fix window expression with alias (#463)
---
ballista/rust/core/proto/ballista.proto | 14 +-
.../rust/core/src/serde/logical_plan/from_proto.rs | 12 +-
.../rust/core/src/serde/logical_plan/to_proto.rs | 39 ++--
.../core/src/serde/physical_plan/from_proto.rs | 14 +-
datafusion/src/logical_plan/builder.rs | 24 +--
datafusion/src/logical_plan/expr.rs | 23 ++-
datafusion/src/logical_plan/plan.rs | 28 +--
datafusion/src/optimizer/projection_push_down.rs | 45 ++---
datafusion/src/optimizer/utils.rs | 39 ++--
datafusion/src/physical_plan/planner.rs | 7 +-
datafusion/src/sql/mod.rs | 2 +-
datafusion/src/sql/planner.rs | 220 +++++++++++++++++----
datafusion/src/sql/utils.rs | 211 +++++++++++++++++++-
13 files changed, 508 insertions(+), 170 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 0387214..d21cbf6 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -174,6 +174,12 @@ message WindowExprNode {
// udaf = 3
}
LogicalExprNode expr = 4;
+ // repeated LogicalExprNode partition_by = 5;
+ repeated LogicalExprNode order_by = 6;
+ // repeated LogicalExprNode filter = 7;
+ // oneof window_frame {
+ // WindowFrame frame = 8;
+ // }
}
message BetweenNode {
@@ -317,14 +323,6 @@ message AggregateNode {
message WindowNode {
LogicalPlanNode input = 1;
repeated LogicalExprNode window_expr = 2;
- repeated LogicalExprNode partition_by_expr = 3;
- repeated LogicalExprNode order_by_expr = 4;
- // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
- // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
- oneof window_frame {
- WindowFrame frame = 5;
- }
- // TODO add filter by expr
}
enum WindowFrameUnits {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 4847126..522d60c 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -98,9 +98,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
// // FIXME: parse the window_frame data
// let window_frame = None;
LogicalPlanBuilder::from(&input)
- .window(
- window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/
- )?
+ .window(window_expr)?
.build()
.map_err(|e| e.into())
}
@@ -924,6 +922,12 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
.window_function
.as_ref()
.ok_or_else(|| proto_error("Received empty window function"))?;
+ let order_by = expr
+ .order_by
+ .iter()
+ .map(|e| e.try_into())
+ .into_iter()
+ .collect::<Result<Vec<_>, _>>()?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
@@ -939,6 +943,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
AggregateFunction::from(aggr_function),
),
args: vec![parse_required_expr(&expr.expr)?],
+ order_by,
})
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -957,6 +962,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
BuiltInWindowFunction::from(built_in_function),
),
args: vec![parse_required_expr(&expr.expr)?],
+ order_by,
})
}
}
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index e1c0c5e..088e931 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -761,27 +761,9 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
})
}
LogicalPlan::Window {
- input,
- window_expr,
- // FIXME implement next
- // filter_by_expr,
- // FIXME implement next
- // partition_by_expr,
- // FIXME implement next
- // order_by_expr,
- // FIXME implement next
- // window_frame,
- ..
+ input, window_expr, ..
} => {
let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
- // FIXME: implement
- // let filter_by_expr = vec![];
- // FIXME: implement
- let partition_by_expr = vec![];
- // FIXME: implement
- let order_by_expr = vec![];
- // FIXME: implement
- let window_frame = None;
Ok(protobuf::LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::Window(Box::new(
protobuf::WindowNode {
@@ -789,10 +771,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
window_expr: window_expr
.iter()
.map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, BallistaError>>()?,
- partition_by_expr,
- order_by_expr,
- window_frame,
+ .collect::<Result<Vec<_>, _>>()?,
},
))),
})
@@ -811,11 +790,11 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
group_expr: group_expr
.iter()
.map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, BallistaError>>()?,
+ .collect::<Result<Vec<_>, _>>()?,
aggr_expr: aggr_expr
.iter()
.map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, BallistaError>>()?,
+ .collect::<Result<Vec<_>, _>>()?,
},
))),
})
@@ -1024,7 +1003,10 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
})
}
Expr::WindowFunction {
- ref fun, ref args, ..
+ ref fun,
+ ref args,
+ ref order_by,
+ ..
} => {
let window_function = match fun {
WindowFunction::AggregateFunction(fun) => {
@@ -1039,9 +1021,14 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};
let arg = &args[0];
+ let order_by = order_by
+ .iter()
+ .map(|e| e.try_into())
+ .collect::<Result<Vec<_>, _>>()?;
let window_expr = Box::new(protobuf::WindowExprNode {
expr: Some(Box::new(arg.try_into()?)),
window_function: Some(window_function),
+ order_by,
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::WindowExpr(window_expr)),
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 7f98a83..c19739a 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -233,7 +233,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
for (expr, name) in &window_agg_expr {
match expr {
- Expr::WindowFunction { fun, args } => {
+ Expr::WindowFunction {
+ fun,
+ args,
+ order_by,
+ } => {
let arg = df_planner
.create_physical_expr(
&args[0],
@@ -243,12 +247,16 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.map_err(|e| {
BallistaError::General(format!("{:?}", e))
})?;
- physical_window_expr.push(create_window_expr(
+ if !order_by.is_empty() {
+ return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned()));
+ }
+ let window_expr = create_window_expr(
&fun,
&[arg],
&physical_schema,
name.to_owned(),
- )?);
+ )?;
+ physical_window_expr.push(window_expr);
}
_ => {
return Err(BallistaError::General(
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 71de48c..dc80a41 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -297,23 +297,7 @@ impl LogicalPlanBuilder {
/// - https://github.com/apache/arrow-datafusion/issues/299 with partition clause
/// - https://github.com/apache/arrow-datafusion/issues/360 with order by
/// - https://github.com/apache/arrow-datafusion/issues/361 with window frame
- pub fn window(
- &self,
- window_expr: impl IntoIterator<Item = Expr>,
- // FIXME: implement next
- // filter_by_expr: impl IntoIterator<Item = Expr>,
- // FIXME: implement next
- // partition_by_expr: impl IntoIterator<Item = Expr>,
- // FIXME: implement next
- // order_by_expr: impl IntoIterator<Item = Expr>,
- // FIXME: implement next
- // window_frame: Option<WindowFrame>,
- ) -> Result<Self> {
- let window_expr = window_expr.into_iter().collect::<Vec<_>>();
- // FIXME: implement next
- // let partition_by_expr = partition_by_expr.into_iter().collect::<Vec<Expr>>();
- // FIXME: implement next
- // let order_by_expr = order_by_expr.into_iter().collect::<Vec<Expr>>();
+ pub fn window(&self, window_expr: Vec<Expr>) -> Result<Self> {
let all_expr = window_expr.iter();
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
@@ -323,12 +307,6 @@ impl LogicalPlanBuilder {
Ok(Self::from(&LogicalPlan::Window {
input: Arc::new(self.plan.clone()),
- // FIXME implement next
- // partition_by_expr,
- // FIXME implement next
- // order_by_expr,
- // FIXME implement next
- // window_frame,
window_expr,
schema: Arc::new(DFSchema::new(window_fields)?),
}))
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 29723e7..5103d5d 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -197,6 +197,8 @@ pub enum Expr {
fun: window_functions::WindowFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
+ /// List of order by expressions
+ order_by: Vec<Expr>,
},
/// aggregate function
AggregateUDF {
@@ -587,9 +589,15 @@ impl Expr {
Expr::ScalarUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
- Expr::WindowFunction { args, .. } => args
- .iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
+ Expr::WindowFunction { args, order_by, .. } => {
+ let visitor = args
+ .iter()
+ .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
+ let visitor = order_by
+ .iter()
+ .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
+ Ok(visitor)
+ }
Expr::AggregateFunction { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
@@ -723,9 +731,14 @@ impl Expr {
args: rewrite_vec(args, rewriter)?,
fun,
},
- Expr::WindowFunction { args, fun } => Expr::WindowFunction {
+ Expr::WindowFunction {
+ args,
+ fun,
+ order_by,
+ } => Expr::WindowFunction {
args: rewrite_vec(args, rewriter)?,
fun,
+ order_by: rewrite_vec(order_by, rewriter)?,
},
Expr::AggregateFunction {
args,
@@ -1388,7 +1401,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_name(&fun.name, false, args, input_schema)
}
- Expr::WindowFunction { fun, args } => {
+ Expr::WindowFunction { fun, args, .. } => {
create_function_name(&fun.to_string(), false, args, input_schema)
}
Expr::AggregateFunction {
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 5cb94be..fe1dfb6 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -92,8 +92,6 @@ pub enum LogicalPlan {
// filter_by_expr: Vec<Expr>,
/// Partition by expressions
// partition_by_expr: Vec<Expr>,
- /// Order by expressions
- // order_by_expr: Vec<Expr>,
/// Window Frame
// window_frame: Option<WindowFrame>,
/// The schema description of the window output
@@ -306,25 +304,12 @@ impl LogicalPlan {
Partitioning::Hash(expr, _) => expr.clone(),
_ => vec![],
},
- LogicalPlan::Window {
- window_expr,
- // FIXME implement next
- // filter_by_expr,
- // FIXME implement next
- // partition_by_expr,
- // FIXME implement next
- // order_by_expr,
- ..
- } => window_expr.clone(),
+ LogicalPlan::Window { window_expr, .. } => window_expr.clone(),
LogicalPlan::Aggregate {
group_expr,
aggr_expr,
..
- } => {
- let mut result = group_expr.clone();
- result.extend(aggr_expr.clone());
- result
- }
+ } => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
LogicalPlan::Join { on, .. } => {
on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect()
}
@@ -698,16 +683,11 @@ impl LogicalPlan {
..
} => write!(f, "Filter: {:?}", expr),
LogicalPlan::Window {
- ref window_expr,
- // FIXME implement next
- // ref partition_by_expr,
- // FIXME implement next
- // ref order_by_expr,
- ..
+ ref window_expr, ..
} => {
write!(
f,
- "WindowAggr: windowExpr=[{:?}] partitionBy=[], orderBy=[]",
+ "WindowAggr: windowExpr=[{:?}] partitionBy=[]",
window_expr
)
}
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index e47832b..f0b364a 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -23,6 +23,7 @@ use crate::execution::context::ExecutionProps;
use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
+use crate::sql::utils::find_sort_exprs;
use arrow::datatypes::Schema;
use arrow::error::Result as ArrowResult;
use std::{collections::HashSet, sync::Arc};
@@ -197,29 +198,29 @@ fn optimize_plan(
schema,
window_expr,
input,
- // FIXME implement next
- // filter_by_expr,
- // FIXME implement next
- // partition_by_expr,
- // FIXME implement next
- // order_by_expr,
- // FIXME implement next
- // window_frame,
..
} => {
// Gather all columns needed for expressions in this Window
let mut new_window_expr = Vec::new();
- window_expr.iter().try_for_each(|expr| {
- let name = &expr.name(&schema)?;
- if required_columns.contains(name) {
- new_window_expr.push(expr.clone());
- new_required_columns.insert(name.clone());
- // add to the new set of required columns
- utils::expr_to_column_names(expr, &mut new_required_columns)
- } else {
- Ok(())
- }
- })?;
+ {
+ window_expr.iter().try_for_each(|expr| {
+ let name = &expr.name(&schema)?;
+ if required_columns.contains(name) {
+ new_window_expr.push(expr.clone());
+ new_required_columns.insert(name.clone());
+ // add to the new set of required columns
+ utils::expr_to_column_names(expr, &mut new_required_columns)
+ } else {
+ Ok(())
+ }
+ })?;
+ }
+
+ // for all the retained window expr, find their sort expressions if any, and retain these
+ utils::exprlist_to_column_names(
+ &find_sort_exprs(&new_window_expr),
+ &mut new_required_columns,
+ )?;
let new_schema = DFSchema::new(
schema
@@ -232,12 +233,6 @@ fn optimize_plan(
Ok(LogicalPlan::Window {
window_expr: new_window_expr,
- // FIXME implement next
- // partition_by_expr: partition_by_expr.clone(),
- // FIXME implement next
- // order_by_expr: order_by_expr.clone(),
- // FIXME implement next
- // window_frame: window_frame.clone(),
input: Arc::new(optimize_plan(
optimizer,
&input,
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 284ead2..2cb6506 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -36,6 +36,7 @@ use crate::{
const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
+const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__";
/// Recursively walk a list of expression trees, collecting the unique set of column
/// names referenced in the expression
@@ -190,14 +191,6 @@ pub fn from_plan(
}),
},
LogicalPlan::Window {
- // FIXME implement next
- // filter_by_expr,
- // FIXME implement next
- // partition_by_expr,
- // FIXME implement next
- // order_by_expr,
- // FIXME implement next
- // window_frame,
window_expr,
schema,
..
@@ -265,7 +258,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]),
Expr::ScalarFunction { args, .. } => Ok(args.clone()),
Expr::ScalarUDF { args, .. } => Ok(args.clone()),
- Expr::WindowFunction { args, .. } => Ok(args.clone()),
+ Expr::WindowFunction { args, order_by, .. } => {
+ let mut expr_list: Vec<Expr> = vec![];
+ expr_list.extend(args.clone());
+ expr_list.push(lit(WINDOW_SORT_MARKER));
+ expr_list.extend(order_by.clone());
+ Ok(expr_list)
+ }
Expr::AggregateFunction { args, .. } => Ok(args.clone()),
Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::Case {
@@ -338,10 +337,24 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
- Expr::WindowFunction { fun, .. } => Ok(Expr::WindowFunction {
- fun: fun.clone(),
- args: expressions.to_vec(),
- }),
+ Expr::WindowFunction { fun, .. } => {
+ let index = expressions
+ .iter()
+ .position(|expr| {
+ matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
+ if str == WINDOW_SORT_MARKER)
+ })
+ .ok_or_else(|| {
+ DataFusionError::Internal(
+ "Ill-formed window function expressions".to_owned(),
+ )
+ })?;
+ Ok(Expr::WindowFunction {
+ fun: fun.clone(),
+ args: expressions[..index].to_vec(),
+ order_by: expressions[index + 1..].to_vec(),
+ })
+ }
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: expressions.to_vec(),
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 4971a02..b77850f 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -746,13 +746,18 @@ impl DefaultPhysicalPlanner {
};
match e {
- Expr::WindowFunction { fun, args } => {
+ Expr::WindowFunction { fun, args, .. } => {
let args = args
.iter()
.map(|e| {
self.create_physical_expr(e, physical_input_schema, ctx_state)
})
.collect::<Result<Vec<_>>>()?;
+ // if !order_by.is_empty() {
+ // return Err(DataFusionError::NotImplemented(
+ // "Window function with order by is not yet implemented".to_owned(),
+ // ));
+ // }
windows::create_window_expr(fun, &args, physical_input_schema, name)
}
other => Err(DataFusionError::Internal(format!(
diff --git a/datafusion/src/sql/mod.rs b/datafusion/src/sql/mod.rs
index 456ad4c..cc8b004 100644
--- a/datafusion/src/sql/mod.rs
+++ b/datafusion/src/sql/mod.rs
@@ -20,4 +20,4 @@
pub mod parser;
pub mod planner;
-mod utils;
+pub(crate) mod utils;
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 63499aa..3b8acc6 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -54,8 +54,8 @@ use super::{
parser::DFParser,
utils::{
can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases,
- find_aggregate_exprs, find_column_exprs, find_window_exprs, rebase_expr,
- resolve_aliases_to_exprs,
+ find_aggregate_exprs, find_column_exprs, find_window_exprs,
+ group_window_expr_by_sort_keys, rebase_expr, resolve_aliases_to_exprs,
},
};
@@ -628,7 +628,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let (plan, exprs) = if window_func_exprs.is_empty() {
(plan, select_exprs_post_aggr)
} else {
- self.window(&plan, window_func_exprs, &select_exprs_post_aggr)?
+ self.window(plan, window_func_exprs, &select_exprs_post_aggr)?
};
let plan = if select.distinct {
@@ -670,13 +670,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Wrap a plan in a window
fn window(
&self,
- input: &LogicalPlan,
+ input: LogicalPlan,
window_exprs: Vec<Expr>,
select_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>)> {
- let plan = LogicalPlanBuilder::from(input)
- .window(window_exprs.clone())?
- .build()?;
+ let mut plan = input;
+ let mut groups = group_window_expr_by_sort_keys(&window_exprs)?;
+ // sort by sort_key len descending, so that more deeply sorted plans get nested further
+ // down as children; to further mimic the behavior of PostgreSQL, we want stable sort
+ // and a reverse so that tying sort keys are reversed in order; note that by this rule
+ // if there's an empty over, it'll be at the top level
+ groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len()));
+ groups.reverse();
+ for (sort_keys, exprs) in groups {
+ if !sort_keys.is_empty() {
+ let sort_keys: Vec<Expr> = sort_keys.to_vec();
+ plan = LogicalPlanBuilder::from(&plan).sort(sort_keys)?.build()?;
+ }
+ let window_exprs: Vec<Expr> = exprs.into_iter().cloned().collect();
+ plan = LogicalPlanBuilder::from(&plan)
+ .window(window_exprs)?
+ .build()?;
+ }
let select_exprs = select_exprs
.iter()
.map(|expr| rebase_expr(expr, &window_exprs, &plan))
@@ -779,21 +794,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Ok(plan.clone());
}
- let input_schema = plan.schema();
- let order_by_rex: Result<Vec<Expr>> = order_by
+ let order_by_rex = order_by
.iter()
- .map(|e| {
- Ok(Expr::Sort {
- expr: Box::new(self.sql_to_rex(&e.expr, &input_schema)?),
- // by default asc
- asc: e.asc.unwrap_or(true),
- // by default nulls first to be consistent with spark
- nulls_first: e.nulls_first.unwrap_or(true),
- })
- })
- .collect();
+ .map(|e| self.order_by_to_sort_expr(e))
+ .into_iter()
+ .collect::<Result<Vec<_>>>()?;
- LogicalPlanBuilder::from(&plan).sort(order_by_rex?)?.build()
+ LogicalPlanBuilder::from(&plan).sort(order_by_rex)?.build()
+ }
+
+ /// convert sql OrderByExpr to Expr::Sort
+ fn order_by_to_sort_expr(&self, e: &OrderByExpr) -> Result<Expr> {
+ Ok(Expr::Sort {
+ expr: Box::new(self.sql_expr_to_logical_expr(&e.expr)?),
+ // by default asc
+ asc: e.asc.unwrap_or(true),
+ // by default nulls first to be consistent with spark
+ nulls_first: e.nulls_first.unwrap_or(true),
+ })
}
/// Validate the schema provides all of the columns referenced in the expressions.
@@ -982,7 +1000,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?),
UnaryOperator::Minus => {
match expr.as_ref() {
- // optimization: if it's a number literal, we applly the negative operator
+ // optimization: if it's a number literal, we apply the negative operator
// here directly to calculate the new literal.
SQLExpr::Value(Value::Number(n,_)) => match n.parse::<i64>() {
Ok(n) => Ok(lit(-n)),
@@ -1091,10 +1109,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// then, window function
if let Some(window) = &function.over {
- if window.partition_by.is_empty()
- && window.order_by.is_empty()
- && window.window_frame.is_none()
- {
+ if window.partition_by.is_empty() && window.window_frame.is_none() {
+ let order_by = window
+ .order_by
+ .iter()
+ .map(|e| self.order_by_to_sort_expr(e))
+ .into_iter()
+ .collect::<Result<Vec<_>>>()?;
let fun = window_functions::WindowFunction::from_str(&name);
if let Ok(window_functions::WindowFunction::AggregateFunction(
aggregate_fun,
@@ -1106,6 +1127,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
),
args: self
.aggregate_fn_to_expr(&aggregate_fun, function)?,
+ order_by,
});
} else if let Ok(
window_functions::WindowFunction::BuiltInWindowFunction(
@@ -1118,6 +1140,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
window_fun,
),
args:self.function_args_to_expr(function)?,
+ order_by
});
}
}
@@ -2702,7 +2725,7 @@ mod tests {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(order_id)\
- \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2712,7 +2735,7 @@ mod tests {
let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders";
let expected = "\
Projection: #order_id AS oid, #MAX(order_id) AS max_oid\
- \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2722,7 +2745,7 @@ mod tests {
let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(qty Multiply Float64(1.1))\
- \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2733,7 +2756,7 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[], orderBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2749,14 +2772,139 @@ mod tests {
);
}
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------------------
+ /// WindowAgg (cost=137.16..154.66 rows=1000 width=12)
+ /// -> Sort (cost=137.16..139.66 rows=1000 width=12)
+ /// Sort Key: order_id
+ /// -> WindowAgg (cost=69.83..87.33 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id DESC
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ #[test]
+ fn over_order_by() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id DESC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// -----------------------------------------------------------------------------------
+ /// WindowAgg (cost=142.16..162.16 rows=1000 width=16)
+ /// -> Sort (cost=142.16..144.66 rows=1000 width=16)
+ /// Sort Key: order_id
+ /// -> WindowAgg (cost=72.33..92.33 rows=1000 width=16)
+ /// -> Sort (cost=72.33..74.83 rows=1000 width=12)
+ /// Sort Key: ((order_id + 1))
+ /// -> Seq Scan on orders (cost=0.00..22.50 rows=1000 width=12)
+ /// ```
+ #[test]
+ fn over_order_by_two_sort_keys() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------------------------
+ /// WindowAgg (cost=139.66..172.16 rows=1000 width=24)
+ /// -> WindowAgg (cost=139.66..159.66 rows=1000 width=16)
+ /// -> Sort (cost=139.66..142.16 rows=1000 width=12)
+ /// Sort Key: qty, order_id
+ /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id, qty
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ #[test]
+ fn over_order_by_sort_keys_sorting() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------------------
+ /// WindowAgg (cost=69.83..117.33 rows=1000 width=24)
+ /// -> WindowAgg (cost=69.83..104.83 rows=1000 width=16)
+ /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id, qty
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ ///
+ /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase
#[test]
- fn over_order_by_not_supported() {
- let sql = "SELECT order_id, MAX(delivered) OVER (order BY order_id) from orders";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "NotImplemented(\"Unsupported OVER clause (ORDER BY order_id)\")",
- format!("{:?}", err)
- );
+ fn over_order_by_sort_keys_sorting_prefix_compacting() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------------------------
+ /// WindowAgg (cost=139.66..172.16 rows=1000 width=24)
+ /// -> WindowAgg (cost=139.66..159.66 rows=1000 width=16)
+ /// -> Sort (cost=139.66..142.16 rows=1000 width=12)
+ /// Sort Key: order_id, qty
+ /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: qty, order_id
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ ///
+ /// FIXME: for now we are not detecting prefix of sorting keys in order to re-arrange with global
+ /// sort
+ #[test]
+ fn over_order_by_sort_keys_sorting_global_order_compacting() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id";
+ let expected = "\
+ Sort: #order_id ASC NULLS FIRST\
+ \n Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
}
#[test]
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 70b9df0..80a25d0 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+//! SQL Utility Functions
+
use crate::logical_plan::{DFSchema, Expr, LogicalPlan};
use crate::{
error::{DataFusionError, Result},
@@ -46,6 +48,14 @@ pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
})
}
+/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence
+/// (depth first), with duplicates omitted.
+pub(crate) fn find_sort_exprs(exprs: &[Expr]) -> Vec<Expr> {
+ find_exprs_in_exprs(exprs, &|nested_expr| {
+ matches!(nested_expr, Expr::Sort { .. })
+ })
+}
+
/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
/// (depth first), with duplicates omitted.
pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
@@ -225,12 +235,20 @@ where
.collect::<Result<Vec<Expr>>>()?,
distinct: *distinct,
}),
- Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction {
+ Expr::WindowFunction {
+ fun,
+ args,
+ order_by,
+ } => Ok(Expr::WindowFunction {
fun: fun.clone(),
args: args
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
- .collect::<Result<Vec<Expr>>>()?,
+ .collect::<Result<Vec<_>>>()?,
+ order_by: order_by
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<_>>>()?,
}),
Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
@@ -389,3 +407,192 @@ pub(crate) fn resolve_aliases_to_exprs(
_ => Ok(None),
})
}
+
+/// Group a slice of window expressions by their `ORDER BY` sort keys, keeping groups in order of first occurrence
+pub(crate) fn group_window_expr_by_sort_keys(
+ window_expr: &[Expr],
+) -> Result<Vec<(&[Expr], Vec<&Expr>)>> {
+ let mut result = vec![];
+ window_expr.iter().try_for_each(|expr| match expr {
+ Expr::WindowFunction { order_by, .. } => {
+ if let Some((_, values)) = result.iter_mut().find(
+ |group: &&mut (&[Expr], Vec<&Expr>)| matches!(group, (key, _) if key == order_by),
+ ) {
+ values.push(expr);
+ } else {
+ result.push((order_by, vec![expr]))
+ }
+ Ok(())
+ }
+ other => Err(DataFusionError::Internal(format!(
+ "Impossibly got non-window expr {:?}",
+ other,
+ ))),
+ })?;
+ Ok(result)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::logical_plan::col;
+ use crate::physical_plan::aggregates::AggregateFunction;
+ use crate::physical_plan::window_functions::WindowFunction;
+
+ #[test]
+ fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
+ let result = group_window_expr_by_sort_keys(&[])?;
+ let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![];
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
+ #[test]
+ fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
+ let max1 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("name")],
+ order_by: vec![],
+ };
+ let max2 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("name")],
+ order_by: vec![],
+ };
+ let min3 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
+ args: vec![col("name")],
+ order_by: vec![],
+ };
+ let sum4 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ args: vec![col("age")],
+ order_by: vec![],
+ };
+ // FIXME use as_ref
+ let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
+ let result = group_window_expr_by_sort_keys(exprs)?;
+ let key = &[];
+ let expected: Vec<(&[Expr], Vec<&Expr>)> =
+ vec![(key, vec![&max1, &max2, &min3, &sum4])];
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
+ #[test]
+ fn test_group_window_expr_by_sort_keys() -> Result<()> {
+ let age_asc = Expr::Sort {
+ expr: Box::new(col("age")),
+ asc: true,
+ nulls_first: true,
+ };
+ let name_desc = Expr::Sort {
+ expr: Box::new(col("name")),
+ asc: false,
+ nulls_first: true,
+ };
+ let created_at_desc = Expr::Sort {
+ expr: Box::new(col("created_at")),
+ asc: false,
+ nulls_first: true,
+ };
+ let max1 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("name")],
+ order_by: vec![age_asc.clone(), name_desc.clone()],
+ };
+ let max2 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("name")],
+ order_by: vec![],
+ };
+ let min3 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
+ args: vec![col("name")],
+ order_by: vec![age_asc.clone(), name_desc.clone()],
+ };
+ let sum4 = Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ args: vec![col("age")],
+ order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
+ };
+ // FIXME use as_ref
+ let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
+ let result = group_window_expr_by_sort_keys(exprs)?;
+
+ let key1 = &[age_asc.clone(), name_desc.clone()];
+ let key2 = &[];
+ let key3 = &[name_desc, age_asc, created_at_desc];
+
+ let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![
+ (key1, vec![&max1, &min3]),
+ (key2, vec![&max2]),
+ (key3, vec![&sum4]),
+ ];
+ assert_eq!(expected, result);
+ Ok(())
+ }
+
+ #[test]
+ fn test_find_sort_exprs() -> Result<()> {
+ let exprs = &[
+ Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
+ args: vec![col("name")],
+ order_by: vec![
+ Expr::Sort {
+ expr: Box::new(col("age")),
+ asc: true,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("name")),
+ asc: false,
+ nulls_first: true,
+ },
+ ],
+ },
+ Expr::WindowFunction {
+ fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
+ args: vec![col("age")],
+ order_by: vec![
+ Expr::Sort {
+ expr: Box::new(col("name")),
+ asc: false,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("age")),
+ asc: true,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("created_at")),
+ asc: false,
+ nulls_first: true,
+ },
+ ],
+ },
+ ];
+ let expected = vec![
+ Expr::Sort {
+ expr: Box::new(col("age")),
+ asc: true,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("name")),
+ asc: false,
+ nulls_first: true,
+ },
+ Expr::Sort {
+ expr: Box::new(col("created_at")),
+ asc: false,
+ nulls_first: true,
+ },
+ ];
+ let result = find_sort_exprs(exprs);
+ assert_eq!(expected, result);
+ Ok(())
+ }
+}