You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/05/21 10:08:03 UTC
[arrow-datafusion] branch master updated: Add window expression
part 1 - logical and physical planning, structure, to/from proto,
and explain, for empty over clause only (#334)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new db4f098 Add window expression part 1 - logical and physical planning, structure, to/from proto, and explain, for empty over clause only (#334)
db4f098 is described below
commit db4f098d38993b96ce1134c4bc7bf5c6579509cf
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Fri May 21 18:07:56 2021 +0800
Add window expression part 1 - logical and physical planning, structure, to/from proto, and explain, for empty over clause only (#334)
* add window expr
* fix unused imports
* fix clippy
* fix unit test
* Update datafusion/src/logical_plan/builder.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Update datafusion/src/logical_plan/builder.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Update datafusion/src/physical_plan/window_functions.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Update datafusion/src/physical_plan/window_functions.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* adding more built-in functions
* add TODO for filter-by expression support
* enrich unit test
* update
* add more tests
* fix test
* fix unit test
* fix error
* fix unit test
* fix unit test
* use upper case
* fix unit test
* comment out test
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
ballista/rust/core/proto/ballista.proto | 80 ++++-
.../rust/core/src/serde/logical_plan/from_proto.rs | 197 +++++++++++-
.../rust/core/src/serde/logical_plan/to_proto.rs | 126 +++++++-
.../core/src/serde/physical_plan/from_proto.rs | 81 ++++-
ballista/rust/scheduler/src/planner.rs | 8 +
datafusion/src/logical_plan/builder.rs | 57 +++-
datafusion/src/logical_plan/expr.rs | 33 +-
datafusion/src/logical_plan/plan.rs | 66 +++-
datafusion/src/optimizer/constant_folding.rs | 1 +
datafusion/src/optimizer/hash_build_probe_order.rs | 5 +
datafusion/src/optimizer/projection_push_down.rs | 55 ++++
datafusion/src/optimizer/utils.rs | 23 ++
datafusion/src/physical_plan/aggregates.rs | 3 +-
datafusion/src/physical_plan/mod.rs | 19 ++
datafusion/src/physical_plan/planner.rs | 67 +++-
datafusion/src/physical_plan/sort.rs | 1 +
datafusion/src/physical_plan/window_functions.rs | 342 +++++++++++++++++++++
datafusion/src/physical_plan/windows.rs | 195 ++++++++++++
datafusion/src/sql/planner.rs | 211 +++++++++----
datafusion/src/sql/utils.rs | 15 +
datafusion/tests/sql.rs | 15 +
21 files changed, 1498 insertions(+), 102 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 3da0e85..da0c615 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -39,7 +39,6 @@ message LogicalExprNode {
ScalarValue literal = 3;
-
// binary expressions
BinaryExprNode binary_expr = 4;
@@ -60,6 +59,9 @@ message LogicalExprNode {
bool wildcard = 15;
ScalarFunctionNode scalar_function = 16;
TryCastNode try_cast = 17;
+
+ // window expressions
+ WindowExprNode window_expr = 18;
}
}
@@ -151,6 +153,29 @@ message AggregateExprNode {
LogicalExprNode expr = 2;
}
+enum BuiltInWindowFunction {
+ ROW_NUMBER = 0;
+ RANK = 1;
+ DENSE_RANK = 2;
+ PERCENT_RANK = 3;
+ CUME_DIST = 4;
+ NTILE = 5;
+ LAG = 6;
+ LEAD = 7;
+ FIRST_VALUE = 8;
+ LAST_VALUE = 9;
+ NTH_VALUE = 10;
+}
+
+message WindowExprNode {
+ oneof window_function {
+ AggregateFunction aggr_function = 1;
+ BuiltInWindowFunction built_in_function = 2;
+ // udaf = 3
+ }
+ LogicalExprNode expr = 4;
+}
+
message BetweenNode {
LogicalExprNode expr = 1;
bool negated = 2;
@@ -200,6 +225,7 @@ message LogicalPlanNode {
EmptyRelationNode empty_relation = 10;
CreateExternalTableNode create_external_table = 11;
ExplainNode explain = 12;
+ WindowNode window = 13;
}
}
@@ -288,6 +314,50 @@ message AggregateNode {
repeated LogicalExprNode aggr_expr = 3;
}
+message WindowNode {
+ LogicalPlanNode input = 1;
+ repeated LogicalExprNode window_expr = 2;
+ repeated LogicalExprNode partition_by_expr = 3;
+ repeated LogicalExprNode order_by_expr = 4;
+ // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
+ // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
+ oneof window_frame {
+ WindowFrame frame = 5;
+ }
+ // TODO add filter by expr
+}
+
+enum WindowFrameUnits {
+ ROWS = 0;
+ RANGE = 1;
+ GROUPS = 2;
+}
+
+message WindowFrame {
+ WindowFrameUnits window_frame_units = 1;
+ WindowFrameBound start_bound = 2;
+ // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
+ // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
+ oneof end_bound {
+ WindowFrameBound bound = 3;
+ }
+}
+
+enum WindowFrameBoundType {
+ CURRENT_ROW = 0;
+ PRECEDING = 1;
+ FOLLOWING = 2;
+}
+
+message WindowFrameBound {
+ WindowFrameBoundType window_frame_bound_type = 1;
+ // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430)
+ // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
+ oneof bound_value {
+ uint64 value = 2;
+ }
+}
+
enum JoinType {
INNER = 0;
LEFT = 1;
@@ -334,6 +404,7 @@ message PhysicalPlanNode {
MergeExecNode merge = 14;
UnresolvedShuffleExecNode unresolved = 15;
RepartitionExecNode repartition = 16;
+ WindowAggExecNode window = 17;
}
}
@@ -399,6 +470,13 @@ enum AggregateMode {
FINAL_PARTITIONED = 2;
}
+message WindowAggExecNode {
+ PhysicalPlanNode input = 1;
+ repeated LogicalExprNode window_expr = 2;
+ repeated string window_expr_name = 3;
+ Schema input_schema = 4;
+}
+
message HashAggregateExecNode {
repeated LogicalExprNode group_expr = 1;
repeated LogicalExprNode aggr_expr = 2;
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 6987035..020858f 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -17,15 +17,15 @@
//! Serde code to convert from protocol buffers to Rust data structures.
+use crate::error::BallistaError;
+use crate::serde::{proto_error, protobuf};
+use crate::{convert_box_required, convert_required};
+use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::{
convert::{From, TryInto},
unimplemented,
};
-use crate::error::BallistaError;
-use crate::serde::{proto_error, protobuf};
-use crate::{convert_box_required, convert_required};
-
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
@@ -33,6 +33,7 @@ use datafusion::logical_plan::{
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
+use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
use datafusion::scalar::ScalarValue;
use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
@@ -75,6 +76,34 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.build()
.map_err(|e| e.into())
}
+ LogicalPlanType::Window(window) => {
+ let input: LogicalPlan = convert_box_required!(window.input)?;
+ let window_expr = window
+ .window_expr
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // let partition_by_expr = window
+ // .partition_by_expr
+ // .iter()
+ // .map(|expr| expr.try_into())
+ // .collect::<Result<Vec<_>, _>>()?;
+ // let order_by_expr = window
+ // .order_by_expr
+ // .iter()
+ // .map(|expr| expr.try_into())
+ // .collect::<Result<Vec<_>, _>>()?;
+ // // FIXME: add filter by expr
+ // // FIXME: parse the window_frame data
+ // let window_frame = None;
+ LogicalPlanBuilder::from(&input)
+ .window(
+ window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/
+ )?
+ .build()
+ .map_err(|e| e.into())
+ }
LogicalPlanType::Aggregate(aggregate) => {
let input: LogicalPlan = convert_box_required!(aggregate.input)?;
let group_expr = aggregate
@@ -871,7 +900,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
type Error = BallistaError;
fn try_into(self) -> Result<Expr, Self::Error> {
+ use datafusion::physical_plan::window_functions;
use protobuf::logical_expr_node::ExprType;
+ use protobuf::window_expr_node;
+ use protobuf::WindowExprNode;
let expr_type = self
.expr_type
@@ -889,6 +921,48 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
Ok(Expr::Literal(scalar_value))
}
+ ExprType::WindowExpr(expr) => {
+ let window_function = expr
+ .window_function
+ .as_ref()
+ .ok_or_else(|| proto_error("Received empty window function"))?;
+ match window_function {
+ window_expr_node::WindowFunction::AggrFunction(i) => {
+ let aggr_function = protobuf::AggregateFunction::from_i32(*i)
+ .ok_or_else(|| {
+ proto_error(format!(
+ "Received an unknown aggregate window function: {}",
+ i
+ ))
+ })?;
+
+ Ok(Expr::WindowFunction {
+ fun: window_functions::WindowFunction::AggregateFunction(
+ AggregateFunction::from(aggr_function),
+ ),
+ args: vec![parse_required_expr(&expr.expr)?],
+ })
+ }
+ window_expr_node::WindowFunction::BuiltInFunction(i) => {
+ let built_in_function =
+ protobuf::BuiltInWindowFunction::from_i32(*i).ok_or_else(
+ || {
+ proto_error(format!(
+ "Received an unknown built-in window function: {}",
+ i
+ ))
+ },
+ )?;
+
+ Ok(Expr::WindowFunction {
+ fun: window_functions::WindowFunction::BuiltInWindowFunction(
+ BuiltInWindowFunction::from(built_in_function),
+ ),
+ args: vec![parse_required_expr(&expr.expr)?],
+ })
+ }
+ }
+ }
ExprType::AggregateExpr(expr) => {
let aggr_function =
protobuf::AggregateFunction::from_i32(expr.aggr_function)
@@ -898,13 +972,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
expr.aggr_function
))
})?;
- let fun = match aggr_function {
- protobuf::AggregateFunction::Min => AggregateFunction::Min,
- protobuf::AggregateFunction::Max => AggregateFunction::Max,
- protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
- protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
- protobuf::AggregateFunction::Count => AggregateFunction::Count,
- };
+ let fun = AggregateFunction::from(aggr_function);
Ok(Expr::AggregateFunction {
fun,
@@ -1152,6 +1220,7 @@ impl TryInto<arrow::datatypes::Field> for &protobuf::Field {
}
use datafusion::physical_plan::datetime_expressions::{date_trunc, to_timestamp};
+use datafusion::physical_plan::{aggregates, windows};
use datafusion::prelude::{
array, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper,
};
@@ -1202,3 +1271,109 @@ fn parse_optional_expr(
None => Ok(None),
}
}
+
+impl From<protobuf::WindowFrameUnits> for WindowFrameUnits {
+ fn from(units: protobuf::WindowFrameUnits) -> Self {
+ match units {
+ protobuf::WindowFrameUnits::Rows => WindowFrameUnits::Rows,
+ protobuf::WindowFrameUnits::Range => WindowFrameUnits::Range,
+ protobuf::WindowFrameUnits::Groups => WindowFrameUnits::Groups,
+ }
+ }
+}
+
+impl TryFrom<protobuf::WindowFrameBound> for WindowFrameBound {
+ type Error = BallistaError;
+
+ fn try_from(bound: protobuf::WindowFrameBound) -> Result<Self, Self::Error> {
+ let bound_type = protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type).ok_or_else(|| {
+ proto_error(format!(
+ "Received a WindowFrameBound message with unknown WindowFrameBoundType {}",
+ bound.window_frame_bound_type
+ ))
+ })?;
+ match bound_type {
+ protobuf::WindowFrameBoundType::CurrentRow => {
+ Ok(WindowFrameBound::CurrentRow)
+ }
+ protobuf::WindowFrameBoundType::Preceding => {
+ // FIXME implement bound value parsing
+ Ok(WindowFrameBound::Preceding(Some(1)))
+ }
+ protobuf::WindowFrameBoundType::Following => {
+ // FIXME implement bound value parsing
+ Ok(WindowFrameBound::Following(Some(1)))
+ }
+ }
+ }
+}
+
+impl TryFrom<protobuf::WindowFrame> for WindowFrame {
+ type Error = BallistaError;
+
+ fn try_from(window: protobuf::WindowFrame) -> Result<Self, Self::Error> {
+ let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units)
+ .ok_or_else(|| {
+ proto_error(format!(
+ "Received a WindowFrame message with unknown WindowFrameUnits {}",
+ window.window_frame_units
+ ))
+ })?
+ .into();
+ let start_bound = window
+ .start_bound
+ .ok_or_else(|| {
+ proto_error(
+ "Received a WindowFrame message with no start_bound".to_owned(),
+ )
+ })?
+ .try_into()?;
+ // FIXME parse end bound
+ let end_bound = None;
+ Ok(WindowFrame {
+ units,
+ start_bound,
+ end_bound,
+ })
+ }
+}
+
+impl From<protobuf::AggregateFunction> for AggregateFunction {
+ fn from(aggr_function: protobuf::AggregateFunction) -> Self {
+ match aggr_function {
+ protobuf::AggregateFunction::Min => AggregateFunction::Min,
+ protobuf::AggregateFunction::Max => AggregateFunction::Max,
+ protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
+ protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
+ protobuf::AggregateFunction::Count => AggregateFunction::Count,
+ }
+ }
+}
+
+impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
+ fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
+ match built_in_function {
+ protobuf::BuiltInWindowFunction::RowNumber => {
+ BuiltInWindowFunction::RowNumber
+ }
+ protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
+ protobuf::BuiltInWindowFunction::PercentRank => {
+ BuiltInWindowFunction::PercentRank
+ }
+ protobuf::BuiltInWindowFunction::DenseRank => {
+ BuiltInWindowFunction::DenseRank
+ }
+ protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
+ protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
+ protobuf::BuiltInWindowFunction::FirstValue => {
+ BuiltInWindowFunction::FirstValue
+ }
+ protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
+ protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
+ protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
+ protobuf::BuiltInWindowFunction::LastValue => {
+ BuiltInWindowFunction::LastValue
+ }
+ }
+ }
+}
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 01b669d..47e2748 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -26,16 +26,19 @@ use std::{
use crate::datasource::DfTableAdapter;
use crate::serde::{protobuf, BallistaError};
-
use arrow::datatypes::{DataType, Schema};
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
use datafusion::physical_plan::aggregates::AggregateFunction;
+use datafusion::physical_plan::window_functions::{
+ BuiltInWindowFunction, WindowFunction,
+};
use datafusion::{datasource::parquet::ParquetTable, logical_plan::exprlist_to_fields};
use protobuf::{
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, Field,
PrimitiveScalarType, ScalarListValue, ScalarType,
};
+use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use super::super::proto_error;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
@@ -772,6 +775,43 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
))),
})
}
+ LogicalPlan::Window {
+ input,
+ window_expr,
+ // FIXME implement next
+ // filter_by_expr,
+ // FIXME implement next
+ // partition_by_expr,
+ // FIXME implement next
+ // order_by_expr,
+ // FIXME implement next
+ // window_frame,
+ ..
+ } => {
+ let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?;
+ // FIXME: implement
+ // let filter_by_expr = vec![];
+ // FIXME: implement
+ let partition_by_expr = vec![];
+ // FIXME: implement
+ let order_by_expr = vec![];
+ // FIXME: implement
+ let window_frame = None;
+ Ok(protobuf::LogicalPlanNode {
+ logical_plan_type: Some(LogicalPlanType::Window(Box::new(
+ protobuf::WindowNode {
+ input: Some(Box::new(input)),
+ window_expr: window_expr
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, BallistaError>>()?,
+ partition_by_expr,
+ order_by_expr,
+ window_frame,
+ },
+ ))),
+ })
+ }
LogicalPlan::Aggregate {
input,
group_expr,
@@ -997,6 +1037,30 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
expr_type: Some(ExprType::BinaryExpr(binary_expr)),
})
}
+ Expr::WindowFunction {
+ ref fun, ref args, ..
+ } => {
+ let window_function = match fun {
+ WindowFunction::AggregateFunction(fun) => {
+ protobuf::window_expr_node::WindowFunction::AggrFunction(
+ protobuf::AggregateFunction::from(fun).into(),
+ )
+ }
+ WindowFunction::BuiltInWindowFunction(fun) => {
+ protobuf::window_expr_node::WindowFunction::BuiltInFunction(
+ protobuf::BuiltInWindowFunction::from(fun).into(),
+ )
+ }
+ };
+ let arg = &args[0];
+ let window_expr = Box::new(protobuf::WindowExprNode {
+ expr: Some(Box::new(arg.try_into()?)),
+ window_function: Some(window_function),
+ });
+ Ok(protobuf::LogicalExprNode {
+ expr_type: Some(ExprType::WindowExpr(window_expr)),
+ })
+ }
Expr::AggregateFunction {
ref fun, ref args, ..
} => {
@@ -1178,6 +1242,66 @@ impl Into<protobuf::Schema> for &Schema {
}
}
+impl From<&AggregateFunction> for protobuf::AggregateFunction {
+ fn from(value: &AggregateFunction) -> Self {
+ match value {
+ AggregateFunction::Min => Self::Min,
+ AggregateFunction::Max => Self::Max,
+ AggregateFunction::Sum => Self::Sum,
+ AggregateFunction::Avg => Self::Avg,
+ AggregateFunction::Count => Self::Count,
+ }
+ }
+}
+
+impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction {
+ fn from(value: &BuiltInWindowFunction) -> Self {
+ match value {
+ BuiltInWindowFunction::FirstValue => Self::FirstValue,
+ BuiltInWindowFunction::LastValue => Self::LastValue,
+ BuiltInWindowFunction::NthValue => Self::NthValue,
+ BuiltInWindowFunction::Ntile => Self::Ntile,
+ BuiltInWindowFunction::CumeDist => Self::CumeDist,
+ BuiltInWindowFunction::PercentRank => Self::PercentRank,
+ BuiltInWindowFunction::RowNumber => Self::RowNumber,
+ BuiltInWindowFunction::Rank => Self::Rank,
+ BuiltInWindowFunction::Lag => Self::Lag,
+ BuiltInWindowFunction::Lead => Self::Lead,
+ BuiltInWindowFunction::DenseRank => Self::DenseRank,
+ }
+ }
+}
+
+impl From<WindowFrameUnits> for protobuf::WindowFrameUnits {
+ fn from(units: WindowFrameUnits) -> Self {
+ match units {
+ WindowFrameUnits::Rows => protobuf::WindowFrameUnits::Rows,
+ WindowFrameUnits::Range => protobuf::WindowFrameUnits::Range,
+ WindowFrameUnits::Groups => protobuf::WindowFrameUnits::Groups,
+ }
+ }
+}
+
+impl TryFrom<WindowFrameBound> for protobuf::WindowFrameBound {
+ type Error = BallistaError;
+
+ fn try_from(_bound: WindowFrameBound) -> Result<Self, Self::Error> {
+ Err(BallistaError::NotImplemented(
+ "WindowFrameBound => protobuf::WindowFrameBound".to_owned(),
+ ))
+ }
+}
+
+impl TryFrom<WindowFrame> for protobuf::WindowFrame {
+ type Error = BallistaError;
+
+ fn try_from(_window: WindowFrame) -> Result<Self, Self::Error> {
+ Err(BallistaError::NotImplemented(
+ "WindowFrame => protobuf::WindowFrame".to_owned(),
+ ))
+ }
+}
+
impl TryFrom<&arrow::datatypes::DataType> for protobuf::ScalarType {
type Error = BallistaError;
fn try_from(value: &arrow::datatypes::DataType) -> Result<Self, Self::Error> {
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 97f0394..d034f3c 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -28,7 +28,6 @@ use crate::serde::protobuf::LogicalExprNode;
use crate::serde::scheduler::PartitionLocation;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};
-
use arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::catalog::catalog::{
CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
@@ -43,6 +42,11 @@ use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec
use datafusion::physical_plan::hash_join::PartitionMode;
use datafusion::physical_plan::merge::MergeExec;
use datafusion::physical_plan::planner::DefaultPhysicalPlanner;
+use datafusion::physical_plan::window_functions::{
+ BuiltInWindowFunction, WindowFunction,
+};
+use datafusion::physical_plan::windows::create_window_expr;
+use datafusion::physical_plan::windows::WindowAggExec;
use datafusion::physical_plan::{
coalesce_batches::CoalesceBatchesExec,
csv::CsvExec,
@@ -58,7 +62,7 @@ use datafusion::physical_plan::{
sort::{SortExec, SortOptions},
Partitioning,
};
-use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr};
+use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr};
use datafusion::prelude::CsvReadOptions;
use log::debug;
use protobuf::logical_expr_node::ExprType;
@@ -189,6 +193,77 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
let input: Arc<dyn ExecutionPlan> = convert_box_required!(limit.input)?;
Ok(Arc::new(LocalLimitExec::new(input, limit.limit as usize)))
}
+ PhysicalPlanType::Window(window_agg) => {
+ let input: Arc<dyn ExecutionPlan> =
+ convert_box_required!(window_agg.input)?;
+ let input_schema = window_agg
+ .input_schema
+ .as_ref()
+ .ok_or_else(|| {
+ BallistaError::General(
+ "input_schema in WindowAggrNode is missing.".to_owned(),
+ )
+ })?
+ .clone();
+
+ let physical_schema: SchemaRef =
+ SchemaRef::new((&input_schema).try_into()?);
+
+ let catalog_list =
+ Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
+ let ctx_state = ExecutionContextState {
+ catalog_list,
+ scalar_functions: Default::default(),
+ var_provider: Default::default(),
+ aggregate_functions: Default::default(),
+ config: ExecutionConfig::new(),
+ execution_props: ExecutionProps::new(),
+ };
+
+ let window_agg_expr: Vec<(Expr, String)> = window_agg
+ .window_expr
+ .iter()
+ .zip(window_agg.window_expr_name.iter())
+ .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone())))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let mut physical_window_expr = vec![];
+
+ let df_planner = DefaultPhysicalPlanner::default();
+
+ for (expr, name) in &window_agg_expr {
+ match expr {
+ Expr::WindowFunction { fun, args } => {
+ let arg = df_planner
+ .create_physical_expr(
+ &args[0],
+ &physical_schema,
+ &ctx_state,
+ )
+ .map_err(|e| {
+ BallistaError::General(format!("{:?}", e))
+ })?;
+ physical_window_expr.push(create_window_expr(
+ &fun,
+ &[arg],
+ &physical_schema,
+ name.to_owned(),
+ )?);
+ }
+ _ => {
+ return Err(BallistaError::General(
+ "Invalid expression for WindowAggrExec".to_string(),
+ ));
+ }
+ }
+ }
+
+ Ok(Arc::new(WindowAggExec::try_new(
+ physical_window_expr,
+ input,
+ Arc::new((&input_schema).try_into()?),
+ )?))
+ }
PhysicalPlanType::HashAggregate(hash_agg) => {
let input: Arc<dyn ExecutionPlan> =
convert_box_required!(hash_agg.input)?;
@@ -222,7 +297,6 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone())))
.collect::<Result<Vec<_>, _>>()?;
- let df_planner = DefaultPhysicalPlanner::default();
let catalog_list =
Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
let ctx_state = ExecutionContextState {
@@ -248,6 +322,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
let mut physical_aggr_expr = vec![];
+ let df_planner = DefaultPhysicalPlanner::default();
for (expr, name) in &logical_agg_expr {
match expr {
Expr::AggregateFunction { fun, args, .. } => {
diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs
index 2f01e73..b1d999b 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -35,6 +35,7 @@ use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use datafusion::physical_plan::hash_join::HashJoinExec;
use datafusion::physical_plan::merge::MergeExec;
+use datafusion::physical_plan::windows::WindowAggExec;
use datafusion::physical_plan::ExecutionPlan;
use log::info;
@@ -150,6 +151,13 @@ impl DistributedPlanner {
} else if let Some(join) = execution_plan.as_any().downcast_ref::<HashJoinExec>()
{
Ok((join.with_new_children(children)?, stages))
+ } else if let Some(window) =
+ execution_plan.as_any().downcast_ref::<WindowAggExec>()
+ {
+ Err(BallistaError::NotImplemented(format!(
+ "WindowAggExec with window {:?}",
+ window
+ )))
} else {
// TODO check for compatible partitioning schema, not just count
if execution_plan.output_partitioning().partition_count()
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 2e69814..9515ac2 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -24,18 +24,17 @@ use arrow::{
record_batch::RecordBatch,
};
+use super::dfschema::ToDFSchema;
+use super::{
+ col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan,
+};
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
+use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning};
use crate::{
datasource::{empty::EmptyTable, parquet::ParquetTable, CsvFile, MemTable},
prelude::CsvReadOptions,
};
-
-use super::dfschema::ToDFSchema;
-use super::{
- col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan,
-};
-use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning};
use std::collections::HashSet;
/// Builder for logical plans
@@ -289,6 +288,52 @@ impl LogicalPlanBuilder {
}))
}
+ /// Apply a window
+ ///
+ /// NOTE: this feature is under development and this API will be changing
+ ///
+ /// - https://github.com/apache/arrow-datafusion/issues/359 basic structure
+ /// - https://github.com/apache/arrow-datafusion/issues/298 empty over clause
+ /// - https://github.com/apache/arrow-datafusion/issues/299 with partition clause
+ /// - https://github.com/apache/arrow-datafusion/issues/360 with order by
+ /// - https://github.com/apache/arrow-datafusion/issues/361 with window frame
+ pub fn window(
+ &self,
+ window_expr: impl IntoIterator<Item = Expr>,
+ // FIXME: implement next
+ // filter_by_expr: impl IntoIterator<Item = Expr>,
+ // FIXME: implement next
+ // partition_by_expr: impl IntoIterator<Item = Expr>,
+ // FIXME: implement next
+ // order_by_expr: impl IntoIterator<Item = Expr>,
+ // FIXME: implement next
+ // window_frame: Option<WindowFrame>,
+ ) -> Result<Self> {
+ let window_expr = window_expr.into_iter().collect::<Vec<Expr>>();
+ // FIXME: implement next
+ // let partition_by_expr = partition_by_expr.into_iter().collect::<Vec<Expr>>();
+ // FIXME: implement next
+ // let order_by_expr = order_by_expr.into_iter().collect::<Vec<Expr>>();
+ let all_expr = window_expr.iter();
+ validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
+
+ let mut window_fields: Vec<DFField> =
+ exprlist_to_fields(all_expr, self.plan.schema())?;
+ window_fields.extend_from_slice(self.plan.schema().fields());
+
+ Ok(Self::from(&LogicalPlan::Window {
+ input: Arc::new(self.plan.clone()),
+ // FIXME implement next
+ // partition_by_expr,
+ // FIXME implement next
+ // order_by_expr,
+ // FIXME implement next
+ // window_frame,
+ window_expr,
+ schema: Arc::new(DFSchema::new(window_fields)?),
+ }))
+ }
+
/// Apply an aggregate: grouping on the `group_expr` expressions
/// and calculating `aggr_expr` aggregates for each distinct
/// value of the `group_expr`;
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 3365bf2..ab02559 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -30,6 +30,7 @@ use crate::error::{DataFusionError, Result};
use crate::logical_plan::{DFField, DFSchema};
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
+ window_functions,
};
use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
@@ -190,6 +191,13 @@ pub enum Expr {
/// Whether this is a DISTINCT aggregation or not
distinct: bool,
},
+ /// Represents the call of a window function with arguments.
+ WindowFunction {
+ /// Name of the function
+ fun: window_functions::WindowFunction,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ },
/// aggregate function
AggregateUDF {
/// The function
@@ -244,6 +252,13 @@ impl Expr {
.collect::<Result<Vec<_>>>()?;
functions::return_type(fun, &data_types)
}
+ Expr::WindowFunction { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ window_functions::return_type(fun, &data_types)
+ }
Expr::AggregateFunction { fun, args, .. } => {
let data_types = args
.iter()
@@ -316,6 +331,7 @@ impl Expr {
Expr::TryCast { .. } => Ok(true),
Expr::ScalarFunction { .. } => Ok(true),
Expr::ScalarUDF { .. } => Ok(true),
+ Expr::WindowFunction { .. } => Ok(true),
Expr::AggregateFunction { .. } => Ok(true),
Expr::AggregateUDF { .. } => Ok(true),
Expr::Not(expr) => expr.nullable(input_schema),
@@ -571,6 +587,9 @@ impl Expr {
Expr::ScalarUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
+ Expr::WindowFunction { args, .. } => args
+ .iter()
+ .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
Expr::AggregateFunction { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
@@ -704,6 +723,10 @@ impl Expr {
args: rewrite_vec(args, rewriter)?,
fun,
},
+ Expr::WindowFunction { args, fun } => Expr::WindowFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
Expr::AggregateFunction {
args,
fun,
@@ -1151,7 +1174,7 @@ pub fn create_udf(
}
/// Creates a new UDAF with a specific signature, state type and return type.
-/// The signature and state type must match the `Acumulator's implementation`.
+/// The signature and state type must match the `Accumulator's implementation`.
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
name: &str,
@@ -1245,6 +1268,9 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args)
}
+ Expr::WindowFunction { fun, ref args, .. } => {
+ fmt_function(f, &fun.to_string(), false, args)
+ }
Expr::AggregateFunction {
fun,
distinct,
@@ -1360,6 +1386,9 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_name(&fun.name, false, args, input_schema)
}
+ Expr::WindowFunction { fun, args } => {
+ create_function_name(&fun.to_string(), false, args, input_schema)
+ }
Expr::AggregateFunction {
fun,
distinct,
@@ -1387,7 +1416,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
}
}
other => Err(DataFusionError::NotImplemented(format!(
- "Physical plan does not support logical expression {:?}",
+ "Create name does not support logical expression {:?}",
other
))),
}
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 8b9aac9..4027916 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -17,24 +17,21 @@
//! This module contains the `LogicalPlan` enum that describes queries
//! via a logical query plan.
-use std::{
- cmp::min,
- fmt::{self, Display},
- sync::Arc,
-};
-
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-
-use crate::datasource::TableProvider;
-use crate::sql::parser::FileType;
-
use super::expr::Expr;
use super::extension::UserDefinedLogicalNode;
use super::{
col,
display::{GraphvizVisitor, IndentVisitor},
};
+use crate::datasource::TableProvider;
use crate::logical_plan::dfschema::DFSchemaRef;
+use crate::sql::parser::FileType;
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use std::{
+ cmp::min,
+ fmt::{self, Display},
+ sync::Arc,
+};
/// Join type
#[derive(Debug, Clone, Copy)]
@@ -83,6 +80,23 @@ pub enum LogicalPlan {
/// The incoming logical plan
input: Arc<LogicalPlan>,
},
+ /// Window its input based on a set of window spec and window function (e.g. SUM or RANK)
+ Window {
+ /// The incoming logical plan
+ input: Arc<LogicalPlan>,
+ /// The window function expression
+ window_expr: Vec<Expr>,
+        // Filter by expressions
+        // filter_by_expr: Vec<Expr>,
+        // Partition by expressions
+        // partition_by_expr: Vec<Expr>,
+        // Order by expressions
+        // order_by_expr: Vec<Expr>,
+        // Window Frame
+        // window_frame: Option<WindowFrame>,
+ /// The schema description of the window output
+ schema: DFSchemaRef,
+ },
/// Aggregates its input based on a set of grouping and aggregate
/// expressions (e.g. SUM).
Aggregate {
@@ -211,6 +225,7 @@ impl LogicalPlan {
} => &projected_schema,
LogicalPlan::Projection { schema, .. } => &schema,
LogicalPlan::Filter { input, .. } => input.schema(),
+ LogicalPlan::Window { schema, .. } => &schema,
LogicalPlan::Aggregate { schema, .. } => &schema,
LogicalPlan::Sort { input, .. } => input.schema(),
LogicalPlan::Join { schema, .. } => &schema,
@@ -230,7 +245,8 @@ impl LogicalPlan {
LogicalPlan::TableScan {
projected_schema, ..
} => vec![&projected_schema],
- LogicalPlan::Aggregate { input, schema, .. }
+ LogicalPlan::Window { input, schema, .. }
+ | LogicalPlan::Aggregate { input, schema, .. }
| LogicalPlan::Projection { input, schema, .. } => {
let mut schemas = input.all_schemas();
schemas.insert(0, &schema);
@@ -288,6 +304,16 @@ impl LogicalPlan {
Partitioning::Hash(expr, _) => expr.clone(),
_ => vec![],
},
+ LogicalPlan::Window {
+ window_expr,
+ // FIXME implement next
+ // filter_by_expr,
+ // FIXME implement next
+ // partition_by_expr,
+ // FIXME implement next
+ // order_by_expr,
+ ..
+ } => window_expr.clone(),
LogicalPlan::Aggregate {
group_expr,
aggr_expr,
@@ -322,6 +348,7 @@ impl LogicalPlan {
LogicalPlan::Projection { input, .. } => vec![input],
LogicalPlan::Filter { input, .. } => vec![input],
LogicalPlan::Repartition { input, .. } => vec![input],
+ LogicalPlan::Window { input, .. } => vec![input],
LogicalPlan::Aggregate { input, .. } => vec![input],
LogicalPlan::Sort { input, .. } => vec![input],
LogicalPlan::Join { left, right, .. } => vec![left, right],
@@ -415,6 +442,7 @@ impl LogicalPlan {
LogicalPlan::Projection { input, .. } => input.accept(visitor)?,
LogicalPlan::Filter { input, .. } => input.accept(visitor)?,
LogicalPlan::Repartition { input, .. } => input.accept(visitor)?,
+ LogicalPlan::Window { input, .. } => input.accept(visitor)?,
LogicalPlan::Aggregate { input, .. } => input.accept(visitor)?,
LogicalPlan::Sort { input, .. } => input.accept(visitor)?,
LogicalPlan::Join { left, right, .. }
@@ -667,6 +695,20 @@ impl LogicalPlan {
predicate: ref expr,
..
} => write!(f, "Filter: {:?}", expr),
+ LogicalPlan::Window {
+ ref window_expr,
+ // FIXME implement next
+ // ref partition_by_expr,
+ // FIXME implement next
+ // ref order_by_expr,
+ ..
+ } => {
+ write!(
+ f,
+ "WindowAggr: windowExpr=[{:?}] partitionBy=[], orderBy=[]",
+ window_expr
+ )
+ }
LogicalPlan::Aggregate {
ref group_expr,
ref aggr_expr,
diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs
index 51bf0ce..af89aa1 100644
--- a/datafusion/src/optimizer/constant_folding.rs
+++ b/datafusion/src/optimizer/constant_folding.rs
@@ -71,6 +71,7 @@ impl OptimizerRule for ConstantFolding {
}),
// Rest: recurse into plan, apply optimization where possible
LogicalPlan::Projection { .. }
+ | LogicalPlan::Window { .. }
| LogicalPlan::Aggregate { .. }
| LogicalPlan::Repartition { .. }
| LogicalPlan::CreateExternalTable { .. }
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 168c4a1..100ae4f 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -54,6 +54,10 @@ fn get_num_rows(logical_plan: &LogicalPlan) -> Option<usize> {
let num_rows_input = get_num_rows(input);
num_rows_input.map(|rows| std::cmp::min(*limit, rows))
}
+ LogicalPlan::Window { input, .. } => {
+            // window functions do not change the number of rows
+ get_num_rows(input)
+ }
LogicalPlan::Aggregate { .. } => {
// we cannot yet predict how many rows will be produced by an aggregate because
// we do not know the cardinality of the grouping keys
@@ -172,6 +176,7 @@ impl OptimizerRule for HashBuildProbeOrder {
}
// Rest: recurse into plan, apply optimization where possible
LogicalPlan::Projection { .. }
+ | LogicalPlan::Window { .. }
| LogicalPlan::Aggregate { .. }
| LogicalPlan::TableScan { .. }
| LogicalPlan::Limit { .. }
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 21c9cab..e47832b 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -193,6 +193,61 @@ fn optimize_plan(
schema: schema.clone(),
})
}
+ LogicalPlan::Window {
+ schema,
+ window_expr,
+ input,
+ // FIXME implement next
+ // filter_by_expr,
+ // FIXME implement next
+ // partition_by_expr,
+ // FIXME implement next
+ // order_by_expr,
+ // FIXME implement next
+ // window_frame,
+ ..
+ } => {
+ // Gather all columns needed for expressions in this Window
+ let mut new_window_expr = Vec::new();
+ window_expr.iter().try_for_each(|expr| {
+ let name = &expr.name(&schema)?;
+ if required_columns.contains(name) {
+ new_window_expr.push(expr.clone());
+ new_required_columns.insert(name.clone());
+ // add to the new set of required columns
+ utils::expr_to_column_names(expr, &mut new_required_columns)
+ } else {
+ Ok(())
+ }
+ })?;
+
+ let new_schema = DFSchema::new(
+ schema
+ .fields()
+ .iter()
+ .filter(|x| new_required_columns.contains(x.name()))
+ .cloned()
+ .collect(),
+ )?;
+
+ Ok(LogicalPlan::Window {
+ window_expr: new_window_expr,
+ // FIXME implement next
+ // partition_by_expr: partition_by_expr.clone(),
+ // FIXME implement next
+ // order_by_expr: order_by_expr.clone(),
+ // FIXME implement next
+ // window_frame: window_frame.clone(),
+ input: Arc::new(optimize_plan(
+ optimizer,
+ &input,
+ &new_required_columns,
+ true,
+ execution_props,
+ )?),
+ schema: DFSchemaRef::new(new_schema),
+ })
+ }
LogicalPlan::Aggregate {
schema,
input,
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 9288c65..284ead2 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -78,6 +78,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
Expr::Sort { .. } => {}
Expr::ScalarFunction { .. } => {}
Expr::ScalarUDF { .. } => {}
+ Expr::WindowFunction { .. } => {}
Expr::AggregateFunction { .. } => {}
Expr::AggregateUDF { .. } => {}
Expr::InList { .. } => {}
@@ -188,6 +189,23 @@ pub fn from_plan(
input: Arc::new(inputs[0].clone()),
}),
},
+ LogicalPlan::Window {
+ // FIXME implement next
+ // filter_by_expr,
+ // FIXME implement next
+ // partition_by_expr,
+ // FIXME implement next
+ // order_by_expr,
+ // FIXME implement next
+ // window_frame,
+ window_expr,
+ schema,
+ ..
+ } => Ok(LogicalPlan::Window {
+ input: Arc::new(inputs[0].clone()),
+ window_expr: expr[0..window_expr.len()].to_vec(),
+ schema: schema.clone(),
+ }),
LogicalPlan::Aggregate {
group_expr, schema, ..
} => Ok(LogicalPlan::Aggregate {
@@ -247,6 +265,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]),
Expr::ScalarFunction { args, .. } => Ok(args.clone()),
Expr::ScalarUDF { args, .. } => Ok(args.clone()),
+ Expr::WindowFunction { args, .. } => Ok(args.clone()),
Expr::AggregateFunction { args, .. } => Ok(args.clone()),
Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::Case {
@@ -319,6 +338,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
+ Expr::WindowFunction { fun, .. } => Ok(Expr::WindowFunction {
+ fun: fun.clone(),
+ args: expressions.to_vec(),
+ }),
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: expressions.to_vec(),
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 9417c7c..3607f29 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -37,7 +37,6 @@ use crate::physical_plan::expressions;
use arrow::datatypes::{DataType, Schema, TimeUnit};
use expressions::{avg_return_type, sum_return_type};
use std::{fmt, str::FromStr, sync::Arc};
-
/// the implementation of an aggregate function
pub type AccumulatorFunctionImplementation =
Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
@@ -183,7 +182,7 @@ static TIMESTAMPS: &[DataType] = &[
];
/// the signatures supported by the function `fun`.
-fn signature(fun: &AggregateFunction) -> Signature {
+pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match fun {
AggregateFunction::Count => Signature::Any(1),
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index e915b2c..c053229 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -442,6 +442,23 @@ pub trait AggregateExpr: Send + Sync + Debug {
}
}
+/// Physical representation of a window expression, i.e. an expression
+/// evaluated over a window of rows (e.g. `SUM(c1) OVER ()`).
+///
+/// At this stage an implementation only has to describe its output field.
+pub trait WindowExpr: Send + Sync + Debug {
+    /// Human readable name such as `"MIN(c2)"` or `"RANK()"`.
+    ///
+    /// The default implementation returns a fixed placeholder string, so
+    /// implementations will usually want to override it.
+    fn name(&self) -> &str {
+        "WindowExpr: default name"
+    }
+
+    /// The field that describes the final result of this window function.
+    fn field(&self) -> Result<Field>;
+
+    /// Returns `self` as [`Any`](std::any::Any) so that callers can downcast
+    /// to a concrete implementation.
+    fn as_any(&self) -> &dyn Any;
+}
+
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values. An accumulator knows how to:
/// * update its state from inputs via `update`
@@ -530,3 +547,5 @@ pub mod udf;
#[cfg(feature = "unicode_expressions")]
pub mod unicode_expressions;
pub mod union;
+pub mod window_functions;
+pub mod windows;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 9e7dc71..018925d 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use super::{
aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary,
- functions, hash_join::PartitionMode, udaf, union::UnionExec,
+ functions, hash_join::PartitionMode, udaf, union::UnionExec, windows,
};
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::{
@@ -39,8 +39,11 @@ use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::sort::SortExec;
use crate::physical_plan::udf;
+use crate::physical_plan::windows::WindowAggExec;
use crate::physical_plan::{hash_utils, Partitioning};
-use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner};
+use crate::physical_plan::{
+ AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner, WindowExpr,
+};
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
use crate::variable::VarType;
@@ -48,10 +51,9 @@ use crate::{
error::{DataFusionError, Result},
physical_plan::displayable,
};
-use arrow::{compute::can_cast_types, datatypes::DataType};
-
use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
+use arrow::{compute::can_cast_types, datatypes::DataType};
use expressions::col;
use log::debug;
@@ -139,6 +141,32 @@ impl DefaultPhysicalPlanner {
limit,
..
} => source.scan(projection, batch_size, filters, *limit),
+ LogicalPlan::Window {
+ input, window_expr, ..
+ } => {
+                // Plan the input first; the window expressions are then planned against its schemas
+ let input_exec = self.create_initial_plan(input, ctx_state)?;
+ let input_schema = input_exec.schema();
+ let physical_input_schema = input_exec.as_ref().schema();
+ let logical_input_schema = input.as_ref().schema();
+ let window_expr = window_expr
+ .iter()
+ .map(|e| {
+ self.create_window_expr(
+ e,
+ &logical_input_schema,
+ &physical_input_schema,
+ ctx_state,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(Arc::new(WindowAggExec::try_new(
+ window_expr,
+ input_exec.clone(),
+ input_schema,
+ )?))
+ }
LogicalPlan::Aggregate {
input,
group_expr,
@@ -700,6 +728,37 @@ impl DefaultPhysicalPlanner {
}
}
+    /// Create a window expression from a logical expression
+    ///
+    /// An outer `Expr::Alias` supplies the output name. Otherwise the name is
+    /// derived from the expression and `logical_input_schema`. The function
+    /// arguments are planned against `physical_input_schema`. Any expression
+    /// other than `Expr::WindowFunction` yields an internal error.
+    pub fn create_window_expr(
+        &self,
+        e: &Expr,
+        logical_input_schema: &DFSchema,
+        physical_input_schema: &Schema,
+        ctx_state: &ExecutionContextState,
+    ) -> Result<Arc<dyn WindowExpr>> {
+        // unpack aliased logical expressions, e.g. "sum(col) over () as total"
+        let (name, window_fn) = if let Expr::Alias(inner, alias) = e {
+            (alias.clone(), inner.as_ref())
+        } else {
+            (e.name(logical_input_schema)?, e)
+        };
+
+        let (fun, args) = match window_fn {
+            Expr::WindowFunction { fun, args } => (fun, args),
+            other => {
+                return Err(DataFusionError::Internal(format!(
+                    "Invalid window expression '{:?}'",
+                    other
+                )))
+            }
+        };
+
+        let mut physical_args = Vec::with_capacity(args.len());
+        for arg in args {
+            physical_args.push(self.create_physical_expr(
+                arg,
+                physical_input_schema,
+                ctx_state,
+            )?);
+        }
+        windows::create_window_expr(fun, &physical_args, physical_input_schema, name)
+    }
+
/// Create an aggregate expression from a logical expression
pub fn create_aggregate_expr(
&self,
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index 8229060..caa32cf 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -135,6 +135,7 @@ impl ExecutionPlan for SortExec {
"SortExec requires a single input partition".to_owned(),
));
}
+
let input = self.input.execute(0).await?;
Ok(Box::pin(SortStream::new(
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
new file mode 100644
index 0000000..65d5373
--- /dev/null
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -0,0 +1,342 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Window functions provide the ability to perform calculations across
+//! sets of rows that are related to the current query row.
+//!
+//! see also https://www.postgresql.org/docs/current/functions-window.html
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+ aggregates, aggregates::AggregateFunction, functions::Signature,
+ type_coercion::data_types,
+};
+use arrow::datatypes::DataType;
+use std::{fmt, str::FromStr};
+
+/// A window function: either an aggregate function evaluated over a window,
+/// or one of the dedicated built-in window functions.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum WindowFunction {
+    /// window function that leverages an aggregate function
+    AggregateFunction(AggregateFunction),
+    /// window function that leverages a built-in window function
+    BuiltInWindowFunction(BuiltInWindowFunction),
+}
+
+impl FromStr for WindowFunction {
+    type Err = DataFusionError;
+
+    /// Parses a window function name after lower-casing it. Aggregate
+    /// function names are tried before built-in window function names.
+    fn from_str(name: &str) -> Result<WindowFunction> {
+        let normalized = name.to_lowercase();
+        let lookup = normalized.as_str();
+
+        if let Ok(aggregate) = AggregateFunction::from_str(lookup) {
+            return Ok(WindowFunction::AggregateFunction(aggregate));
+        }
+
+        BuiltInWindowFunction::from_str(lookup)
+            .map(WindowFunction::BuiltInWindowFunction)
+            .map_err(|_| {
+                DataFusionError::Plan(format!(
+                    "There is no window function named {}",
+                    normalized
+                ))
+            })
+    }
+}
+
+impl fmt::Display for BuiltInWindowFunction {
+    /// Writes the upper-case SQL name of the function.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let sql_name = match self {
+            BuiltInWindowFunction::RowNumber => "ROW_NUMBER",
+            BuiltInWindowFunction::Rank => "RANK",
+            BuiltInWindowFunction::DenseRank => "DENSE_RANK",
+            BuiltInWindowFunction::PercentRank => "PERCENT_RANK",
+            BuiltInWindowFunction::CumeDist => "CUME_DIST",
+            BuiltInWindowFunction::Ntile => "NTILE",
+            BuiltInWindowFunction::Lag => "LAG",
+            BuiltInWindowFunction::Lead => "LEAD",
+            BuiltInWindowFunction::FirstValue => "FIRST_VALUE",
+            BuiltInWindowFunction::LastValue => "LAST_VALUE",
+            BuiltInWindowFunction::NthValue => "NTH_VALUE",
+        };
+        f.write_str(sql_name)
+    }
+}
+
+impl fmt::Display for WindowFunction {
+    /// Delegates to the `Display` implementation of the wrapped function.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            WindowFunction::AggregateFunction(inner) => fmt::Display::fmt(inner, f),
+            WindowFunction::BuiltInWindowFunction(inner) => fmt::Display::fmt(inner, f),
+        }
+    }
+}
+
+/// Built-in window functions, i.e. window functions that are not aggregate
+/// functions (e.g. `ROW_NUMBER`, `RANK`, `LAG`)
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum BuiltInWindowFunction {
+    /// number of the current row within its partition, counting from 1
+    RowNumber,
+    /// rank of the current row with gaps; same as row_number of its first peer
+    Rank,
+    /// rank of the current row without gaps; this function counts peer groups
+    DenseRank,
+    /// relative rank of the current row: (rank - 1) / (total rows - 1)
+    PercentRank,
+    /// cumulative distribution: (number of rows preceding or peer with current row) / (total rows)
+    CumeDist,
+    /// integer ranging from 1 to the argument value, dividing the partition as equally as possible
+    Ntile,
+    /// returns value evaluated at the row that is offset rows before the current row within the partition;
+    /// if there is no such row, instead return default (which must be of the same type as value).
+    /// Both offset and default are evaluated with respect to the current row.
+    /// If omitted, offset defaults to 1 and default to null
+    Lag,
+    /// returns value evaluated at the row that is offset rows after the current row within the partition;
+    /// if there is no such row, instead return default (which must be of the same type as value).
+    /// Both offset and default are evaluated with respect to the current row.
+    /// If omitted, offset defaults to 1 and default to null
+    Lead,
+    /// returns value evaluated at the row that is the first row of the window frame
+    FirstValue,
+    /// returns value evaluated at the row that is the last row of the window frame
+    LastValue,
+    /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
+    NthValue,
+}
+
+impl FromStr for BuiltInWindowFunction {
+    type Err = DataFusionError;
+
+    /// Parses a built-in window function from its SQL name, ignoring case.
+    fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
+        // (upper-case SQL name, function) pairs
+        const NAMES: [(&str, BuiltInWindowFunction); 11] = [
+            ("ROW_NUMBER", BuiltInWindowFunction::RowNumber),
+            ("RANK", BuiltInWindowFunction::Rank),
+            ("DENSE_RANK", BuiltInWindowFunction::DenseRank),
+            ("PERCENT_RANK", BuiltInWindowFunction::PercentRank),
+            ("CUME_DIST", BuiltInWindowFunction::CumeDist),
+            ("NTILE", BuiltInWindowFunction::Ntile),
+            ("LAG", BuiltInWindowFunction::Lag),
+            ("LEAD", BuiltInWindowFunction::Lead),
+            ("FIRST_VALUE", BuiltInWindowFunction::FirstValue),
+            ("LAST_VALUE", BuiltInWindowFunction::LastValue),
+            ("NTH_VALUE", BuiltInWindowFunction::NthValue),
+        ];
+
+        let upper = name.to_uppercase();
+        NAMES
+            .iter()
+            .find(|(candidate, _)| *candidate == upper)
+            .map(|(_, fun)| fun.clone())
+            .ok_or_else(|| {
+                DataFusionError::Plan(format!(
+                    "There is no built-in window function named {}",
+                    name
+                ))
+            })
+    }
+}
+
+/// Returns the datatype of the window function
+///
+/// Errors if `arg_types` does not satisfy the function's signature.
+pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataType> {
+    // Note that this function *must* return the same type that the respective physical expression returns
+    // or the execution panics.
+
+    // reject argument lists that do not match the signature before inspecting them
+    data_types(arg_types, &signature(fun))?;
+
+    match fun {
+        WindowFunction::AggregateFunction(aggregate) => {
+            aggregates::return_type(aggregate, arg_types)
+        }
+        WindowFunction::BuiltInWindowFunction(built_in) => Ok(match built_in {
+            BuiltInWindowFunction::RowNumber
+            | BuiltInWindowFunction::Rank
+            | BuiltInWindowFunction::DenseRank => DataType::UInt64,
+            BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
+                DataType::Float64
+            }
+            BuiltInWindowFunction::Ntile => DataType::UInt32,
+            // value functions return the type of their first argument
+            BuiltInWindowFunction::Lag
+            | BuiltInWindowFunction::Lead
+            | BuiltInWindowFunction::FirstValue
+            | BuiltInWindowFunction::LastValue
+            | BuiltInWindowFunction::NthValue => arg_types[0].clone(),
+        }),
+    }
+}
+
+/// Returns the argument signature accepted by the window function `fun`.
+fn signature(fun: &WindowFunction) -> Signature {
+    // note: the physical expression must accept the type returned by this function or the execution panics.
+    let built_in = match fun {
+        WindowFunction::AggregateFunction(aggregate) => {
+            return aggregates::signature(aggregate)
+        }
+        WindowFunction::BuiltInWindowFunction(built_in) => built_in,
+    };
+
+    match built_in {
+        // ranking / distribution functions take no arguments
+        BuiltInWindowFunction::RowNumber
+        | BuiltInWindowFunction::Rank
+        | BuiltInWindowFunction::DenseRank
+        | BuiltInWindowFunction::PercentRank
+        | BuiltInWindowFunction::CumeDist => Signature::Any(0),
+        BuiltInWindowFunction::Lag
+        | BuiltInWindowFunction::Lead
+        | BuiltInWindowFunction::FirstValue
+        | BuiltInWindowFunction::LastValue => Signature::Any(1),
+        BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
+        BuiltInWindowFunction::NthValue => Signature::Any(2),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Parses `fun_name` and asserts that its return type for `arg_types` is `expected`.
+    fn assert_return_type(
+        fun_name: &str,
+        arg_types: &[DataType],
+        expected: DataType,
+    ) -> Result<()> {
+        let fun = WindowFunction::from_str(fun_name)?;
+        let observed = return_type(&fun, arg_types)?;
+        assert_eq!(expected, observed, "return type of {}", fun_name);
+        Ok(())
+    }
+
+    #[test]
+    fn test_window_function_case_insensitive() -> Result<()> {
+        let names = vec![
+            "row_number",
+            "rank",
+            "dense_rank",
+            "percent_rank",
+            "cume_dist",
+            "ntile",
+            "lag",
+            "lead",
+            "first_value",
+            "last_value",
+            "nth_value",
+            "min",
+            "max",
+            "count",
+            "avg",
+            "sum",
+        ];
+        for name in names {
+            let fun = WindowFunction::from_str(name)?;
+            let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?;
+            assert_eq!(fun, fun2);
+            assert_eq!(fun.to_string(), name.to_uppercase());
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_window_function_from_str() -> Result<()> {
+        assert_eq!(
+            WindowFunction::from_str("max")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Max)
+        );
+        assert_eq!(
+            WindowFunction::from_str("min")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Min)
+        );
+        assert_eq!(
+            WindowFunction::from_str("avg")?,
+            WindowFunction::AggregateFunction(AggregateFunction::Avg)
+        );
+        assert_eq!(
+            WindowFunction::from_str("cume_dist")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist)
+        );
+        assert_eq!(
+            WindowFunction::from_str("first_value")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LAST_value")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LAG")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag)
+        );
+        assert_eq!(
+            WindowFunction::from_str("LEAD")?,
+            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_window_function_from_str_unknown() {
+        assert!(WindowFunction::from_str("no_such_function").is_err());
+        assert!(BuiltInWindowFunction::from_str("no_such_function").is_err());
+    }
+
+    #[test]
+    fn test_count_return_type() -> Result<()> {
+        assert_return_type("count", &[DataType::Utf8], DataType::UInt64)?;
+        assert_return_type("count", &[DataType::UInt64], DataType::UInt64)
+    }
+
+    #[test]
+    fn test_first_value_return_type() -> Result<()> {
+        assert_return_type("first_value", &[DataType::Utf8], DataType::Utf8)?;
+        assert_return_type("first_value", &[DataType::UInt64], DataType::UInt64)
+    }
+
+    #[test]
+    fn test_last_value_return_type() -> Result<()> {
+        assert_return_type("last_value", &[DataType::Utf8], DataType::Utf8)?;
+        assert_return_type("last_value", &[DataType::Float64], DataType::Float64)
+    }
+
+    #[test]
+    fn test_lead_return_type() -> Result<()> {
+        assert_return_type("lead", &[DataType::Utf8], DataType::Utf8)?;
+        assert_return_type("lead", &[DataType::Float64], DataType::Float64)
+    }
+
+    #[test]
+    fn test_lag_return_type() -> Result<()> {
+        assert_return_type("lag", &[DataType::Utf8], DataType::Utf8)?;
+        assert_return_type("lag", &[DataType::Float64], DataType::Float64)
+    }
+
+    #[test]
+    fn test_nth_value_return_type() -> Result<()> {
+        assert_return_type(
+            "nth_value",
+            &[DataType::Utf8, DataType::UInt64],
+            DataType::Utf8,
+        )?;
+        assert_return_type(
+            "nth_value",
+            &[DataType::Float64, DataType::UInt64],
+            DataType::Float64,
+        )
+    }
+
+    #[test]
+    fn test_cume_dist_return_type() -> Result<()> {
+        assert_return_type("cume_dist", &[], DataType::Float64)
+    }
+
+    #[test]
+    fn test_ranking_return_type() -> Result<()> {
+        for name in &["row_number", "rank", "dense_rank"] {
+            assert_return_type(name, &[], DataType::UInt64)?;
+        }
+        assert_return_type("percent_rank", &[], DataType::Float64)
+    }
+
+    #[test]
+    fn test_ntile_return_type() -> Result<()> {
+        assert_return_type("ntile", &[DataType::UInt64], DataType::UInt32)
+    }
+}
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
new file mode 100644
index 0000000..bdd25d6
--- /dev/null
+++ b/datafusion/src/physical_plan/windows.rs
@@ -0,0 +1,195 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Execution plan for window functions
+
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+ aggregates, window_functions::WindowFunction, AggregateExpr, Distribution,
+ ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr,
+};
+use arrow::datatypes::{Field, Schema, SchemaRef};
+use async_trait::async_trait;
+use std::any::Any;
+use std::sync::Arc;
+
+/// Execution plan that evaluates window expressions over its input.
+///
+/// Planning is supported, but `execute` is not implemented yet.
+#[derive(Debug)]
+pub struct WindowAggExec {
+    /// Input plan
+    input: Arc<dyn ExecutionPlan>,
+    /// Window function expressions evaluated by this plan
+    window_expr: Vec<Arc<dyn WindowExpr>>,
+    /// Output schema: window expression fields followed by input fields
+    schema: SchemaRef,
+    /// Schema of the input, before the window expressions are applied
+    input_schema: SchemaRef,
+}
+
+/// Create a physical expression for window function
+///
+/// Aggregate-based window functions wrap the corresponding non-distinct
+/// aggregate expression. Built-in window functions are not supported yet and
+/// produce a `NotImplemented` error.
+pub fn create_window_expr(
+    fun: &WindowFunction,
+    args: &[Arc<dyn PhysicalExpr>],
+    input_schema: &Schema,
+    name: String,
+) -> Result<Arc<dyn WindowExpr>> {
+    let aggregate_fun = match fun {
+        WindowFunction::AggregateFunction(aggregate_fun) => aggregate_fun,
+        WindowFunction::BuiltInWindowFunction(built_in) => {
+            return Err(DataFusionError::NotImplemented(format!(
+                "window function with {:?} not implemented",
+                built_in
+            )))
+        }
+    };
+
+    let distinct = false;
+    let aggregate = aggregates::create_aggregate_expr(
+        aggregate_fun,
+        distinct,
+        args,
+        input_schema,
+        name,
+    )?;
+    Ok(Arc::new(AggregateWindowExpr { aggregate }))
+}
+
+/// A window expr that takes the form of a built in window function
+///
+/// Placeholder: built-in window functions are not implemented yet.
+#[derive(Debug)]
+pub struct BuiltInWindowExpr {}
+
+/// A window expr that takes the form of an aggregate function
+#[derive(Debug)]
+pub struct AggregateWindowExpr {
+    /// The aggregate expression evaluated over the window
+    aggregate: Arc<dyn AggregateExpr>,
+}
+
+impl WindowExpr for AggregateWindowExpr {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Delegates to the name of the wrapped aggregate expression
+    fn name(&self) -> &str {
+        // `AggregateExpr::name` already returns `&str`; no extra borrow needed
+        self.aggregate.name()
+    }
+
+    /// Delegates to the output field of the wrapped aggregate expression
+    fn field(&self) -> Result<Field> {
+        self.aggregate.field()
+    }
+}
+
+/// Builds the output schema of a window plan: one field per window
+/// expression, followed by every field of `input_schema`.
+fn create_schema(
+    input_schema: &Schema,
+    window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Schema> {
+    let window_fields = window_expr
+        .iter()
+        .map(|expr| expr.field())
+        .collect::<Result<Vec<Field>>>()?;
+    let all_fields = window_fields
+        .into_iter()
+        .chain(input_schema.fields().iter().cloned())
+        .collect();
+    Ok(Schema::new(all_fields))
+}
+
+impl WindowAggExec {
+    /// Create a new execution plan for window aggregates
+    ///
+    /// The output schema is derived from `input.schema()` and `window_expr`.
+    /// `input_schema` is stored as given and returned by [`Self::input_schema`].
+    pub fn try_new(
+        window_expr: Vec<Arc<dyn WindowExpr>>,
+        input: Arc<dyn ExecutionPlan>,
+        input_schema: SchemaRef,
+    ) -> Result<Self> {
+        Ok(Self {
+            schema: Arc::new(create_schema(&input.schema(), &window_expr)?),
+            input,
+            window_expr,
+            input_schema,
+        })
+    }
+
+    /// Input plan
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// Get the input schema before any aggregates are applied
+    pub fn input_schema(&self) -> SchemaRef {
+        Arc::clone(&self.input_schema)
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for WindowAggExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The output schema: window expression fields followed by input fields
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![Arc::clone(&self.input)]
+    }
+
+    /// Get the output partitioning of this plan (always one partition)
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(1)
+    }
+
+    /// The input must be merged into a single partition before windowing
+    fn required_child_distribution(&self) -> Distribution {
+        Distribution::SinglePartition
+    }
+
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        if let [child] = children.as_slice() {
+            let plan = WindowAggExec::try_new(
+                self.window_expr.clone(),
+                Arc::clone(child),
+                child.schema(),
+            )?;
+            Ok(Arc::new(plan))
+        } else {
+            Err(DataFusionError::Internal(
+                "WindowAggExec wrong number of children".to_owned(),
+            ))
+        }
+    }
+
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        if partition != 0 {
+            return Err(DataFusionError::Internal(format!(
+                "WindowAggExec invalid partition {}",
+                partition
+            )));
+        }
+
+        // window needs to operate on a single partition currently
+        let input_partitions = self.input.output_partitioning().partition_count();
+        if input_partitions != 1 {
+            return Err(DataFusionError::Internal(
+                "WindowAggExec requires a single input partition".to_owned(),
+            ));
+        }
+
+        // TODO: execute the input and evaluate the window expressions over it
+        Err(DataFusionError::NotImplemented(
+            "WindowAggExec::execute".to_owned(),
+        ))
+    }
+}
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 34c5901..a3027e5 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -35,7 +35,7 @@ use crate::{
};
use crate::{
physical_plan::udf::ScalarUDF,
- physical_plan::{aggregates, functions},
+ physical_plan::{aggregates, functions, window_functions},
sql::parser::{CreateExternalTable, FileType, Statement as DFStatement},
};
@@ -57,7 +57,8 @@ use super::{
parser::DFParser,
utils::{
can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases,
- find_aggregate_exprs, find_column_exprs, rebase_expr, resolve_aliases_to_exprs,
+ find_aggregate_exprs, find_column_exprs, find_window_exprs, rebase_expr,
+ resolve_aliases_to_exprs,
},
};
@@ -413,7 +414,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
))
}
JoinConstraint::None => Err(DataFusionError::NotImplemented(
- "NONE contraint is not supported".to_string(),
+ "NONE constraint is not supported".to_string(),
)),
}
}
@@ -624,15 +625,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
plan
};
+ // window function
+ let window_func_exprs = find_window_exprs(&select_exprs_post_aggr);
+
+ let (plan, exprs) = if window_func_exprs.is_empty() {
+ (plan, select_exprs_post_aggr)
+ } else {
+ self.window(&plan, window_func_exprs, &select_exprs_post_aggr)?
+ };
+
let plan = if select.distinct {
return LogicalPlanBuilder::from(&plan)
- .aggregate(select_exprs_post_aggr, vec![])?
+ .aggregate(exprs, vec![])?
.build();
} else {
plan
};
- self.project(&plan, select_exprs_post_aggr)
+ self.project(&plan, exprs)
}
/// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions.
@@ -657,10 +667,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Wrap a plan in a projection
fn project(&self, input: &LogicalPlan, expr: Vec<Expr>) -> Result<LogicalPlan> {
self.validate_schema_satisfies_exprs(&input.schema(), &expr)?;
-
LogicalPlanBuilder::from(input).project(expr)?.build()
}
+    /// Wrap a plan in a window aggregation node.
+    ///
+    /// Builds a window plan over `input` that evaluates `window_exprs`, then
+    /// rewrites every expression in `select_exprs` with `expr_as_column_expr`
+    /// against that new plan (presumably turning expressions the window node
+    /// already computes into column references — confirm against the helper).
+    /// Returns the new plan together with the rewritten select expressions.
+    ///
+    /// # Errors
+    /// Propagates errors from building the window plan or from
+    /// `expr_as_column_expr`.
+    fn window(
+        &self,
+        input: &LogicalPlan,
+        window_exprs: Vec<Expr>,
+        select_exprs: &[Expr],
+    ) -> Result<(LogicalPlan, Vec<Expr>)> {
+        let plan = LogicalPlanBuilder::from(input)
+            .window(window_exprs)?
+            .build()?;
+        let select_exprs = select_exprs
+            .iter()
+            .map(|expr| expr_as_column_expr(expr, &plan))
+            .collect::<Result<Vec<_>>>()?;
+        Ok((plan, select_exprs))
+    }
+
+ /// Wrap a plan in an aggregate
fn aggregate(
&self,
input: &LogicalPlan,
@@ -1059,70 +1087,69 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// first, scalar built-in
if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {
- let args = function
- .args
- .iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a))
- .collect::<Result<Vec<Expr>>>()?;
+ let args = self.function_args_to_expr(function)?;
return Ok(Expr::ScalarFunction { fun, args });
};
+ // then, window function
+ if let Some(window) = &function.over {
+ if window.partition_by.is_empty()
+ && window.order_by.is_empty()
+ && window.window_frame.is_none()
+ {
+ let fun = window_functions::WindowFunction::from_str(&name);
+ if let Ok(window_functions::WindowFunction::AggregateFunction(
+ aggregate_fun,
+ )) = fun
+ {
+ return Ok(Expr::WindowFunction {
+ fun: window_functions::WindowFunction::AggregateFunction(
+ aggregate_fun.clone(),
+ ),
+ args: self
+ .aggregate_fn_to_expr(&aggregate_fun, function)?,
+ });
+ } else if let Ok(
+ window_functions::WindowFunction::BuiltInWindowFunction(
+ window_fun,
+ ),
+ ) = fun
+ {
+ return Ok(Expr::WindowFunction {
+ fun: window_functions::WindowFunction::BuiltInWindowFunction(
+ window_fun,
+ ),
+ args:self.function_args_to_expr(function)?,
+ });
+ }
+ }
+ return Err(DataFusionError::NotImplemented(format!(
+ "Unsupported OVER clause ({})",
+ window
+ )));
+ }
+
// next, aggregate built-ins
if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) {
- let args = if fun == aggregates::AggregateFunction::Count {
- function
- .args
- .iter()
- .map(|a| match a {
- FunctionArg::Unnamed(SQLExpr::Value(Value::Number(
- _,
- _,
- ))) => Ok(lit(1_u8)),
- FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)),
- _ => self.sql_fn_arg_to_logical_expr(a),
- })
- .collect::<Result<Vec<Expr>>>()?
- } else {
- function
- .args
- .iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a))
- .collect::<Result<Vec<Expr>>>()?
- };
-
- return match &function.over {
- Some(window) => Err(DataFusionError::NotImplemented(format!(
- "Unsupported OVER clause ({})",
- window
- ))),
- _ => Ok(Expr::AggregateFunction {
- fun,
- distinct: function.distinct,
- args,
- }),
- };
+ let args = self.aggregate_fn_to_expr(&fun, function)?;
+ return Ok(Expr::AggregateFunction {
+ fun,
+ distinct: function.distinct,
+ args,
+ });
};
// finally, user-defined functions (UDF) and UDAF
match self.schema_provider.get_function_meta(&name) {
Some(fm) => {
- let args = function
- .args
- .iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a))
- .collect::<Result<Vec<Expr>>>()?;
+ let args = self.function_args_to_expr(function)?;
Ok(Expr::ScalarUDF { fun: fm, args })
}
None => match self.schema_provider.get_aggregate_meta(&name) {
Some(fm) => {
- let args = function
- .args
- .iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a))
- .collect::<Result<Vec<Expr>>>()?;
-
+ let args = self.function_args_to_expr(function)?;
Ok(Expr::AggregateUDF { fun: fm, args })
}
_ => Err(DataFusionError::Plan(format!(
@@ -1142,6 +1169,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
+    /// Convert each argument of the SQL function call `function` into a
+    /// logical `Expr`, keeping argument order; the error of the first
+    /// argument that fails to convert is returned.
+    fn function_args_to_expr(
+        &self,
+        function: &sqlparser::ast::Function,
+    ) -> Result<Vec<Expr>> {
+        function
+            .args
+            .iter()
+            .map(|arg| self.sql_fn_arg_to_logical_expr(arg))
+            .collect()
+    }
+
+    /// Convert the arguments of an aggregate function call into logical `Expr`s.
+    ///
+    /// For `COUNT`, an unnamed numeric literal argument or a `*` wildcard
+    /// (e.g. `COUNT(1)` / `COUNT(*)`) is replaced by the literal `1_u8`; every
+    /// other argument, and all arguments of other aggregates, is converted
+    /// with `sql_fn_arg_to_logical_expr`.
+    fn aggregate_fn_to_expr(
+        &self,
+        fun: &aggregates::AggregateFunction,
+        function: &sqlparser::ast::Function,
+    ) -> Result<Vec<Expr>> {
+        if *fun != aggregates::AggregateFunction::Count {
+            return self.function_args_to_expr(function);
+        }
+
+        function
+            .args
+            .iter()
+            .map(|arg| match arg {
+                // COUNT(<number>) and COUNT(*) are normalised to a constant
+                // literal argument.
+                FunctionArg::Unnamed(SQLExpr::Value(Value::Number(_, _)))
+                | FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)),
+                other => self.sql_fn_arg_to_logical_expr(other),
+            })
+            .collect()
+    }
+
fn sql_interval_to_literal(
&self,
value: &str,
@@ -2641,13 +2701,34 @@ mod tests {
}
#[test]
- fn over_not_supported() {
+ fn empty_over() {
+        // An empty `OVER ()` clause is planned as a WindowAggr node beneath the projection.
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
- let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- "NotImplemented(\"Unsupported OVER clause ()\")",
- format!("{:?}", err)
- );
+ let expected = "\
+ Projection: #order_id, #MAX(order_id)\
+ \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn empty_over_plus() {
+        // The window aggregate's argument may be an arbitrary expression, not just a column.
+ let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty Multiply Float64(1.1))\
+ \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn empty_over_multiple() {
+        // Several empty-OVER window functions (written in mixed case) are planned
+        // into a single WindowAggr node, with upper-cased function names.
+ let sql =
+ "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[], orderBy=[]\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
}
#[test]
@@ -2662,6 +2743,16 @@ mod tests {
}
#[test]
+ fn over_order_by_not_supported() {
+        // Only an empty OVER clause is supported so far; ORDER BY inside OVER is rejected.
+ let sql = "SELECT order_id, MAX(delivered) OVER (order BY order_id) from orders";
+ let err = logical_plan(sql).expect_err("query should have failed");
+ assert_eq!(
+ "NotImplemented(\"Unsupported OVER clause (ORDER BY order_id)\")",
+ format!("{:?}", err)
+ );
+ }
+
+ #[test]
fn only_union_all_supported() {
let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders";
let err = logical_plan(sql).expect_err("query should have failed");
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index f41643d..70b9df0 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -46,6 +46,14 @@ pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec<Expr> {
})
}
+/// Returns every `Expr::WindowFunction` found anywhere inside `exprs`,
+/// including ones nested within other expressions. Results follow depth-first
+/// order of first occurrence, and repeated expressions appear only once.
+pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
+    let is_window_fn = |candidate: &Expr| matches!(candidate, Expr::WindowFunction { .. });
+    find_exprs_in_exprs(exprs, &is_window_fn)
+}
+
/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
/// appearance (depth first), with duplicates omitted.
pub(crate) fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
@@ -217,6 +225,13 @@ where
.collect::<Result<Vec<Expr>>>()?,
distinct: *distinct,
}),
+ Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction {
+ fun: fun.clone(),
+ args: args
+ .iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ }),
Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
args: args
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 17e0f13..e68c53b 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -797,6 +797,21 @@ async fn csv_query_count() -> Result<()> {
Ok(())
}
+// FIXME: uncomment this test once WindowAggExec::execute is implemented
+// #[tokio::test]
+// async fn csv_query_window_with_empty_over() -> Result<()> {
+// let mut ctx = ExecutionContext::new();
+// register_aggregate_csv(&mut ctx)?;
+// let sql = "SELECT count(c12) over () FROM aggregate_test_100";
+// // FIXME: WindowAggExec::execute is not implemented yet, so running this
+// // query currently returns a DataFusionError::NotImplemented error
+
+// let result = execute(&mut ctx, sql).await;
+// let expected: Vec<Vec<String>> = vec![];
+// assert_eq!(result, expected);
+// Ok(())
+// }
+
#[tokio::test]
async fn csv_query_group_by_int_count() -> Result<()> {
let mut ctx = ExecutionContext::new();