You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/09 18:26:09 UTC
[arrow-datafusion] branch master updated: Add `partition by` constructs in window functions and modify logical planning (#501)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new d5bca0e Add `partition by` constructs in window functions and modify logical planning (#501)
d5bca0e is described below
commit d5bca0e350d94a1e1063bed8a0da0cb09c6e3e1c
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Thu Jun 10 02:26:01 2021 +0800
Add `partition by` constructs in window functions and modify logical planning (#501)
* closing up type checks
* add fmt
---
ballista/rust/core/proto/ballista.proto | 2 +-
.../rust/core/src/serde/logical_plan/from_proto.rs | 8 +
.../rust/core/src/serde/logical_plan/to_proto.rs | 6 +
.../core/src/serde/physical_plan/from_proto.rs | 8 +
datafusion/src/logical_plan/expr.rs | 14 +-
datafusion/src/logical_plan/plan.rs | 6 +-
datafusion/src/optimizer/utils.rs | 46 ++++-
datafusion/src/sql/planner.rs | 217 +++++++++++++++------
datafusion/src/sql/utils.rs | 57 ++++--
9 files changed, 280 insertions(+), 84 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 38d87e9..85af902 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -174,7 +174,7 @@ message WindowExprNode {
// udaf = 3
}
LogicalExprNode expr = 4;
- // repeated LogicalExprNode partition_by = 5;
+ repeated LogicalExprNode partition_by = 5;
repeated LogicalExprNode order_by = 6;
// repeated LogicalExprNode filter = 7;
oneof window_frame {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 36a37a1..86daeb0 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -910,6 +910,12 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
.window_function
.as_ref()
.ok_or_else(|| proto_error("Received empty window function"))?;
+ let partition_by = expr
+ .partition_by
+ .iter()
+ .map(|e| e.try_into())
+ .into_iter()
+ .collect::<Result<Vec<_>, _>>()?;
let order_by = expr
.order_by
.iter()
@@ -940,6 +946,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
AggregateFunction::from(aggr_function),
),
args: vec![parse_required_expr(&expr.expr)?],
+ partition_by,
order_by,
window_frame,
})
@@ -960,6 +967,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
BuiltInWindowFunction::from(built_in_function),
),
args: vec![parse_required_expr(&expr.expr)?],
+ partition_by,
order_by,
window_frame,
})
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index fb1383d..5d99684 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1006,6 +1006,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
Expr::WindowFunction {
ref fun,
ref args,
+ ref partition_by,
ref order_by,
ref window_frame,
..
@@ -1023,6 +1024,10 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};
let arg = &args[0];
+ let partition_by = partition_by
+ .iter()
+ .map(|e| e.try_into())
+ .collect::<Result<Vec<_>, _>>()?;
let order_by = order_by
.iter()
.map(|e| e.try_into())
@@ -1035,6 +1040,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
let window_expr = Box::new(protobuf::WindowExprNode {
expr: Some(Box::new(arg.try_into()?)),
window_function: Some(window_function),
+ partition_by,
order_by,
window_frame,
});
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 5fcc971..b319d5b 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -236,7 +236,9 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
Expr::WindowFunction {
fun,
args,
+ partition_by,
order_by,
+ window_frame,
..
} => {
let arg = df_planner
@@ -248,9 +250,15 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.map_err(|e| {
BallistaError::General(format!("{:?}", e))
})?;
+ if !partition_by.is_empty() {
+ return Err(BallistaError::NotImplemented("Window function with partition by is not yet implemented".to_owned()));
+ }
if !order_by.is_empty() {
return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned()));
}
+ if window_frame.is_some() {
+ return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned()));
+ }
let window_expr = create_window_expr(
&fun,
&[arg],
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index d5c92db..58dba16 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -194,6 +194,8 @@ pub enum Expr {
fun: window_functions::WindowFunction,
/// List of expressions to feed to the functions as arguments
args: Vec<Expr>,
+ /// List of partition by expressions
+ partition_by: Vec<Expr>,
/// List of order by expressions
order_by: Vec<Expr>,
/// Window frame
@@ -588,10 +590,18 @@ impl Expr {
Expr::ScalarUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
- Expr::WindowFunction { args, order_by, .. } => {
+ Expr::WindowFunction {
+ args,
+ partition_by,
+ order_by,
+ ..
+ } => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
+ let visitor = partition_by
+ .iter()
+ .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = order_by
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
@@ -733,11 +743,13 @@ impl Expr {
Expr::WindowFunction {
args,
fun,
+ partition_by,
order_by,
window_frame,
} => Expr::WindowFunction {
args: rewrite_vec(args, rewriter)?,
fun,
+ partition_by: rewrite_vec(partition_by, rewriter)?,
order_by: rewrite_vec(order_by, rewriter)?,
window_frame,
},
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 25cf9e3..3344dce 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -687,11 +687,7 @@ impl LogicalPlan {
LogicalPlan::Window {
ref window_expr, ..
} => {
- write!(
- f,
- "WindowAggr: windowExpr=[{:?}] partitionBy=[]",
- window_expr
- )
+ write!(f, "WindowAggr: windowExpr=[{:?}]", window_expr)
}
LogicalPlan::Aggregate {
ref group_expr,
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 65c95be..e707d30 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -36,6 +36,7 @@ use crate::{
const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
+const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__";
const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__";
/// Recursively walk a list of expression trees, collecting the unique set of column
@@ -258,9 +259,16 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]),
Expr::ScalarFunction { args, .. } => Ok(args.clone()),
Expr::ScalarUDF { args, .. } => Ok(args.clone()),
- Expr::WindowFunction { args, order_by, .. } => {
+ Expr::WindowFunction {
+ args,
+ partition_by,
+ order_by,
+ ..
+ } => {
let mut expr_list: Vec<Expr> = vec![];
expr_list.extend(args.clone());
+ expr_list.push(lit(WINDOW_PARTITION_MARKER));
+ expr_list.extend(partition_by.clone());
expr_list.push(lit(WINDOW_SORT_MARKER));
expr_list.extend(order_by.clone());
Ok(expr_list)
@@ -340,7 +348,20 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
Expr::WindowFunction {
fun, window_frame, ..
} => {
- let index = expressions
+ let partition_index = expressions
+ .iter()
+ .position(|expr| {
+ matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
+ if str == WINDOW_PARTITION_MARKER)
+ })
+ .ok_or_else(|| {
+ DataFusionError::Internal(
+ "Ill-formed window function expressions: unexpected marker"
+ .to_owned(),
+ )
+ })?;
+
+ let sort_index = expressions
.iter()
.position(|expr| {
matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str)))
@@ -351,12 +372,21 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
"Ill-formed window function expressions".to_owned(),
)
})?;
- Ok(Expr::WindowFunction {
- fun: fun.clone(),
- args: expressions[..index].to_vec(),
- order_by: expressions[index + 1..].to_vec(),
- window_frame: *window_frame,
- })
+
+ if partition_index >= sort_index {
+ Err(DataFusionError::Internal(
+ "Ill-formed window function expressions: partition index too large"
+ .to_owned(),
+ ))
+ } else {
+ Ok(Expr::WindowFunction {
+ fun: fun.clone(),
+ args: expressions[..partition_index].to_vec(),
+ partition_by: expressions[partition_index + 1..sort_index].to_vec(),
+ order_by: expressions[sort_index + 1..].to_vec(),
+ window_frame: *window_frame,
+ })
+ }
}
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
fun: fun.clone(),
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 7df0068..53f22ec 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -1122,52 +1122,53 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// then, window function
if let Some(window) = &function.over {
- if window.partition_by.is_empty() {
- let order_by = window
- .order_by
- .iter()
- .map(|e| self.order_by_to_sort_expr(e))
- .into_iter()
- .collect::<Result<Vec<_>>>()?;
- let window_frame = window
- .window_frame
- .as_ref()
- .map(|window_frame| window_frame.clone().try_into())
- .transpose()?;
- let fun = window_functions::WindowFunction::from_str(&name);
- if let Ok(window_functions::WindowFunction::AggregateFunction(
+ let partition_by = window
+ .partition_by
+ .iter()
+ .map(|e| self.sql_expr_to_logical_expr(e))
+ .into_iter()
+ .collect::<Result<Vec<_>>>()?;
+ let order_by = window
+ .order_by
+ .iter()
+ .map(|e| self.order_by_to_sort_expr(e))
+ .into_iter()
+ .collect::<Result<Vec<_>>>()?;
+ let window_frame = window
+ .window_frame
+ .as_ref()
+ .map(|window_frame| window_frame.clone().try_into())
+ .transpose()?;
+ let fun = window_functions::WindowFunction::from_str(&name)?;
+ match fun {
+ window_functions::WindowFunction::AggregateFunction(
aggregate_fun,
- )) = fun
- {
+ ) => {
return Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::AggregateFunction(
aggregate_fun.clone(),
),
args: self
.aggregate_fn_to_expr(&aggregate_fun, function)?,
+ partition_by,
order_by,
window_frame,
});
- } else if let Ok(
- window_functions::WindowFunction::BuiltInWindowFunction(
- window_fun,
- ),
- ) = fun
- {
+ }
+ window_functions::WindowFunction::BuiltInWindowFunction(
+ window_fun,
+ ) => {
return Ok(Expr::WindowFunction {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
window_fun,
),
args: self.function_args_to_expr(function)?,
+ partition_by,
order_by,
window_frame,
});
}
}
- return Err(DataFusionError::NotImplemented(format!(
- "Unsupported OVER clause ({})",
- window
- )));
}
// next, aggregate built-ins
@@ -2775,7 +2776,7 @@ mod tests {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(order_id)\
- \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#order_id)]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2785,7 +2786,7 @@ mod tests {
let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders";
let expected = "\
Projection: #order_id AS oid, #MAX(order_id) AS max_oid\
- \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#order_id)]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2795,7 +2796,7 @@ mod tests {
let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(qty Multiply Float64(1.1))\
- \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
@@ -2806,20 +2807,29 @@ mod tests {
"SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]]\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------
+ /// WindowAgg (cost=69.83..87.33 rows=1000 width=8)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
#[test]
- fn over_partition_by_not_supported() {
- let sql =
- "SELECT order_id, MAX(delivered) OVER (PARTITION BY order_id) from orders";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "NotImplemented(\"Unsupported OVER clause (PARTITION BY order_id)\")",
- format!("{:?}", err)
- );
+ fn over_partition_by() {
+ let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
}
/// psql result
@@ -2839,9 +2849,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #MIN(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2852,9 +2862,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2865,9 +2875,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2878,9 +2888,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
let expected = "\
Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id DESC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2903,9 +2913,9 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #MIN(qty)\
- \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2929,10 +2939,10 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
- \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
- \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
\n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2956,10 +2966,10 @@ mod tests {
let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders";
let expected = "\
Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
- \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
- \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
@@ -2987,15 +2997,108 @@ mod tests {
let expected = "\
Sort: #order_id ASC NULLS FIRST\
\n Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\
- \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\
- \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[SUM(#qty)]]\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
\n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\
- \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
\n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
\n TableScan: orders projection=None";
quick_test(sql, expected);
}
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------
+ /// WindowAgg (cost=69.83..89.83 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id, qty
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ #[test]
+ fn over_partition_by_order_by() {
+ let sql =
+ "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------
+ /// WindowAgg (cost=69.83..89.83 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id, qty
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ #[test]
+ fn over_partition_by_order_by_no_dup() {
+ let sql =
+ "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// ----------------------------------------------------------------------------------
+ /// WindowAgg (cost=142.16..162.16 rows=1000 width=16)
+ /// -> Sort (cost=142.16..144.66 rows=1000 width=12)
+ /// Sort Key: qty, order_id
+ /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=12)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=8)
+ /// Sort Key: order_id, qty
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8)
+ /// ```
+ #[test]
+ fn over_partition_by_order_by_mix_up() {
+ let sql =
+ "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
+ \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ /// psql result
+ /// ```
+ /// QUERY PLAN
+ /// -----------------------------------------------------------------------------
+ /// WindowAgg (cost=69.83..109.83 rows=1000 width=24)
+ /// -> WindowAgg (cost=69.83..92.33 rows=1000 width=20)
+ /// -> Sort (cost=69.83..72.33 rows=1000 width=16)
+ /// Sort Key: order_id, qty, price
+ /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=16)
+ /// ```
+ /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase
+ #[test]
+ fn over_partition_by_order_by_mix_up_prefix() {
+ let sql =
+ "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]]\
+ \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST, #price ASC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
#[test]
fn only_union_all_supported() {
let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders";
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 848fb3e..5e9b952 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -239,6 +239,7 @@ where
Expr::WindowFunction {
fun,
args,
+ partition_by,
order_by,
window_frame,
} => Ok(Expr::WindowFunction {
@@ -247,6 +248,10 @@ where
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<_>>>()?,
+ partition_by: partition_by
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<_>>>()?,
order_by: order_by
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
@@ -432,19 +437,38 @@ pub(crate) fn resolve_aliases_to_exprs(
})
}
+type WindowSortKey = Vec<Expr>;
+
+fn generate_sort_key(partition_by: &[Expr], order_by: &[Expr]) -> WindowSortKey {
+ let mut sort_key = vec![];
+ partition_by.iter().for_each(|e| {
+ let e = e.clone().sort(true, true);
+ if !sort_key.contains(&e) {
+ sort_key.push(e);
+ }
+ });
+ order_by.iter().for_each(|e| {
+ if !sort_key.contains(&e) {
+ sort_key.push(e.clone());
+ }
+ });
+ sort_key
+}
+
/// group a slice of window expression expr by their order by expressions
pub(crate) fn group_window_expr_by_sort_keys(
window_expr: &[Expr],
-) -> Result<Vec<(&[Expr], Vec<&Expr>)>> {
+) -> Result<Vec<(WindowSortKey, Vec<&Expr>)>> {
let mut result = vec![];
window_expr.iter().try_for_each(|expr| match expr {
- Expr::WindowFunction { order_by, .. } => {
+ Expr::WindowFunction { partition_by, order_by, .. } => {
+ let sort_key = generate_sort_key(partition_by, order_by);
if let Some((_, values)) = result.iter_mut().find(
- |group: &&mut (&[Expr], Vec<&Expr>)| matches!(group, (key, _) if key == order_by),
+ |group: &&mut (WindowSortKey, Vec<&Expr>)| matches!(group, (key, _) if *key == sort_key),
) {
values.push(expr);
} else {
- result.push((order_by, vec![expr]))
+ result.push((sort_key, vec![expr]))
}
Ok(())
}
@@ -466,7 +490,7 @@ mod tests {
#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
let result = group_window_expr_by_sort_keys(&[])?;
- let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![];
+ let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![];
assert_eq!(expected, result);
Ok(())
}
@@ -476,32 +500,35 @@ mod tests {
let max1 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![],
window_frame: None,
};
let max2 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![],
window_frame: None,
};
let min3 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![],
window_frame: None,
};
let sum4 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
args: vec![col("age")],
+ partition_by: vec![],
order_by: vec![],
window_frame: None,
};
- // FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
- let key = &[];
- let expected: Vec<(&[Expr], Vec<&Expr>)> =
+ let key = vec![];
+ let expected: Vec<(WindowSortKey, Vec<&Expr>)> =
vec![(key, vec![&max1, &max2, &min3, &sum4])];
assert_eq!(expected, result);
Ok(())
@@ -527,24 +554,28 @@ mod tests {
let max1 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![age_asc.clone(), name_desc.clone()],
window_frame: None,
};
let max2 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![],
window_frame: None,
};
let min3 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![age_asc.clone(), name_desc.clone()],
window_frame: None,
};
let sum4 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
args: vec![col("age")],
+ partition_by: vec![],
order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
window_frame: None,
};
@@ -552,11 +583,11 @@ mod tests {
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
- let key1 = &[age_asc.clone(), name_desc.clone()];
- let key2 = &[];
- let key3 = &[name_desc, age_asc, created_at_desc];
+ let key1 = vec![age_asc.clone(), name_desc.clone()];
+ let key2 = vec![];
+ let key3 = vec![name_desc, age_asc, created_at_desc];
- let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![
+ let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
(key1, vec![&max1, &min3]),
(key2, vec![&max2]),
(key3, vec![&sum4]),
@@ -571,6 +602,7 @@ mod tests {
Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
+ partition_by: vec![],
order_by: vec![
Expr::Sort {
expr: Box::new(col("age")),
@@ -588,6 +620,7 @@ mod tests {
Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
args: vec![col("age")],
+ partition_by: vec![],
order_by: vec![
Expr::Sort {
expr: Box::new(col("name")),