You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/07 10:14:32 UTC
[arrow-datafusion] branch master updated: closing up type checks
(#506)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 767eeb0 closing up type checks (#506)
767eeb0 is described below
commit 767eeb0a8bf17916aafb9a88abd52e7350acb596
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Mon Jun 7 18:14:25 2021 +0800
closing up type checks (#506)
---
ballista/rust/core/Cargo.toml | 2 +-
ballista/rust/core/proto/ballista.proto | 6 +-
.../rust/core/src/serde/logical_plan/from_proto.rs | 49 +--
.../rust/core/src/serde/logical_plan/to_proto.rs | 56 ++--
.../core/src/serde/physical_plan/from_proto.rs | 1 +
datafusion/src/logical_plan/expr.rs | 50 ++-
datafusion/src/optimizer/utils.rs | 5 +-
datafusion/src/physical_plan/mod.rs | 1 +
datafusion/src/physical_plan/planner.rs | 3 +-
datafusion/src/physical_plan/window_frames.rs | 337 +++++++++++++++++++++
datafusion/src/sql/planner.rs | 52 +++-
datafusion/src/sql/utils.rs | 12 +
12 files changed, 512 insertions(+), 62 deletions(-)
diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index 99822cf..1f23a2a 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -35,7 +35,7 @@ futures = "0.3"
log = "0.4"
prost = "0.7"
serde = {version = "1", features = ["derive"]}
-sqlparser = "0.8"
+sqlparser = "0.9.0"
tokio = "1.0"
tonic = "0.4"
uuid = { version = "0.8", features = ["v4"] }
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 0ed9f24..38d87e9 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -177,9 +177,9 @@ message WindowExprNode {
// repeated LogicalExprNode partition_by = 5;
repeated LogicalExprNode order_by = 6;
// repeated LogicalExprNode filter = 7;
- // oneof window_frame {
- // WindowFrame frame = 8;
- // }
+ oneof window_frame {
+ WindowFrame frame = 8;
+ }
}
message BetweenNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 662d9d0..4a19817 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -20,12 +20,6 @@
use crate::error::BallistaError;
use crate::serde::{proto_error, protobuf};
use crate::{convert_box_required, convert_required};
-use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
-use std::{
- convert::{From, TryInto},
- unimplemented,
-};
-
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
@@ -33,10 +27,17 @@ use datafusion::logical_plan::{
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
+use datafusion::physical_plan::window_frames::{
+ WindowFrame, WindowFrameBound, WindowFrameUnits,
+};
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
use datafusion::scalar::ScalarValue;
use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
+use std::{
+ convert::{From, TryInto},
+ unimplemented,
+};
// use uuid::Uuid;
@@ -83,20 +84,6 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;
-
- // let partition_by_expr = window
- // .partition_by_expr
- // .iter()
- // .map(|expr| expr.try_into())
- // .collect::<Result<Vec<_>, _>>()?;
- // let order_by_expr = window
- // .order_by_expr
- // .iter()
- // .map(|expr| expr.try_into())
- // .collect::<Result<Vec<_>, _>>()?;
- // // FIXME: add filter by expr
- // // FIXME: parse the window_frame data
- // let window_frame = None;
LogicalPlanBuilder::from(&input)
.window(window_expr)?
.build()
@@ -929,6 +916,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
.map(|e| e.try_into())
.into_iter()
.collect::<Result<Vec<_>, _>>()?;
+ let window_frame = expr
+ .window_frame
+ .as_ref()
+ .map::<Result<WindowFrame, _>, _>(|e| match e {
+ window_expr_node::WindowFrame::Frame(frame) => {
+ frame.clone().try_into()
+ }
+ })
+ .transpose()?;
match window_function {
window_expr_node::WindowFunction::AggrFunction(i) => {
let aggr_function = protobuf::AggregateFunction::from_i32(*i)
@@ -945,6 +941,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
),
args: vec![parse_required_expr(&expr.expr)?],
order_by,
+ window_frame,
})
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -964,6 +961,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
),
args: vec![parse_required_expr(&expr.expr)?],
order_by,
+ window_frame,
})
}
}
@@ -1333,8 +1331,15 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
)
})?
.try_into()?;
- // FIXME parse end bound
- let end_bound = None;
+ let end_bound = window
+ .end_bound
+ .map(|end_bound| match end_bound {
+ protobuf::window_frame::EndBound::Bound(end_bound) => {
+ end_bound.try_into()
+ }
+ })
+ .transpose()?
+ .unwrap_or(WindowFrameBound::CurrentRow);
Ok(WindowFrame {
units,
start_bound,
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index d7734f0..5627003 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -24,12 +24,17 @@ use std::{
convert::{TryFrom, TryInto},
};
+use super::super::proto_error;
use crate::datasource::DfTableAdapter;
use crate::serde::{protobuf, BallistaError};
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{Expr, JoinType, LogicalPlan};
use datafusion::physical_plan::aggregates::AggregateFunction;
+use datafusion::physical_plan::functions::BuiltinScalarFunction;
+use datafusion::physical_plan::window_frames::{
+ WindowFrame, WindowFrameBound, WindowFrameUnits,
+};
use datafusion::physical_plan::window_functions::{
BuiltInWindowFunction, WindowFunction,
};
@@ -38,10 +43,6 @@ use protobuf::{
arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit, PrimitiveScalarType,
ScalarListValue, ScalarType,
};
-use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits};
-
-use super::super::proto_error;
-use datafusion::physical_plan::functions::BuiltinScalarFunction;
impl protobuf::IntervalUnit {
pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
@@ -1007,6 +1008,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
ref fun,
ref args,
ref order_by,
+ ref window_frame,
..
} => {
let window_function = match fun {
@@ -1026,10 +1028,16 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<_>, _>>()?;
+ let window_frame = window_frame.map(|window_frame| {
+ protobuf::window_expr_node::WindowFrame::Frame(
+ window_frame.clone().into(),
+ )
+ });
let window_expr = Box::new(protobuf::WindowExprNode {
expr: Some(Box::new(arg.try_into()?)),
window_function: Some(window_function),
order_by,
+ window_frame,
});
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::WindowExpr(window_expr)),
@@ -1256,23 +1264,35 @@ impl From<WindowFrameUnits> for protobuf::WindowFrameUnits {
}
}
-impl TryFrom<WindowFrameBound> for protobuf::WindowFrameBound {
- type Error = BallistaError;
-
- fn try_from(_bound: WindowFrameBound) -> Result<Self, Self::Error> {
- Err(BallistaError::NotImplemented(
- "WindowFrameBound => protobuf::WindowFrameBound".to_owned(),
- ))
+impl From<WindowFrameBound> for protobuf::WindowFrameBound {
+ fn from(bound: WindowFrameBound) -> Self {
+ match bound {
+ WindowFrameBound::CurrentRow => protobuf::WindowFrameBound {
+ window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow
+ .into(),
+ bound_value: None,
+ },
+ WindowFrameBound::Preceding(v) => protobuf::WindowFrameBound {
+ window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(),
+ bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value),
+ },
+ WindowFrameBound::Following(v) => protobuf::WindowFrameBound {
+ window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(),
+ bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value),
+ },
+ }
}
}
-impl TryFrom<WindowFrame> for protobuf::WindowFrame {
- type Error = BallistaError;
-
- fn try_from(_window: WindowFrame) -> Result<Self, Self::Error> {
- Err(BallistaError::NotImplemented(
- "WindowFrame => protobuf::WindowFrame".to_owned(),
- ))
+impl From<WindowFrame> for protobuf::WindowFrame {
+ fn from(window: WindowFrame) -> Self {
+ protobuf::WindowFrame {
+ window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(),
+ start_bound: Some(window.start_bound.into()),
+ end_bound: Some(protobuf::window_frame::EndBound::Bound(
+ window.end_bound.into(),
+ )),
+ }
}
}
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 2294431..5fcc971 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -237,6 +237,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
fun,
args,
order_by,
+ ..
} => {
let arg = df_planner
.create_physical_expr(
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 5103d5d..bbc6ffa 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -19,22 +19,19 @@
//! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct.
pub use super::Operator;
-
-use std::fmt;
-use std::sync::Arc;
-
-use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
-use arrow::{compute::can_cast_types, datatypes::DataType};
-
use crate::error::{DataFusionError, Result};
use crate::logical_plan::{DFField, DFSchema};
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
- window_functions,
+ window_frames, window_functions,
};
use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
+use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
+use arrow::{compute::can_cast_types, datatypes::DataType};
use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use std::collections::HashSet;
+use std::fmt;
+use std::sync::Arc;
/// `Expr` is a central struct of DataFusion's query API, and
/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
@@ -199,6 +196,8 @@ pub enum Expr {
args: Vec<Expr>,
/// List of order by expressions
order_by: Vec<Expr>,
+ /// Window frame
+ window_frame: Option<window_frames::WindowFrame>,
},
/// aggregate function
AggregateUDF {
@@ -735,10 +734,12 @@ impl Expr {
args,
fun,
order_by,
+ window_frame,
} => Expr::WindowFunction {
args: rewrite_vec(args, rewriter)?,
fun,
order_by: rewrite_vec(order_by, rewriter)?,
+ window_frame,
},
Expr::AggregateFunction {
args,
@@ -1283,8 +1284,23 @@ impl fmt::Debug for Expr {
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args)
}
- Expr::WindowFunction { fun, ref args, .. } => {
- fmt_function(f, &fun.to_string(), false, args)
+ Expr::WindowFunction {
+ fun,
+ ref args,
+ window_frame,
+ ..
+ } => {
+ fmt_function(f, &fun.to_string(), false, args)?;
+ if let Some(window_frame) = window_frame {
+ write!(
+ f,
+ " {} BETWEEN {} AND {}",
+ window_frame.units,
+ window_frame.start_bound,
+ window_frame.end_bound
+ )?;
+ }
+ Ok(())
}
Expr::AggregateFunction {
fun,
@@ -1401,8 +1417,18 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
Expr::ScalarUDF { fun, args, .. } => {
create_function_name(&fun.name, false, args, input_schema)
}
- Expr::WindowFunction { fun, args, .. } => {
- create_function_name(&fun.to_string(), false, args, input_schema)
+ Expr::WindowFunction {
+ fun,
+ args,
+ window_frame,
+ ..
+ } => {
+ let fun_name =
+ create_function_name(&fun.to_string(), false, args, input_schema)?;
+ Ok(match window_frame {
+ Some(window_frame) => format!("{} {}", fun_name, window_frame),
+ None => fun_name,
+ })
}
Expr::AggregateFunction {
fun,
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 2cb6506..65c95be 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
- Expr::WindowFunction { fun, .. } => {
+ Expr::WindowFunction {
+ fun, window_frame, ..
+ } => {
let index = expressions
.iter()
.position(|expr| {
@@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions[..index].to_vec(),
order_by: expressions[index + 1..].to_vec(),
+ window_frame: *window_frame,
})
}
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index af6969c..490e028 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -617,5 +617,6 @@ pub mod udf;
#[cfg(feature = "unicode_expressions")]
pub mod unicode_expressions;
pub mod union;
+pub mod window_frames;
pub mod window_functions;
pub mod windows;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 754ace0..d7451c7 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -17,8 +17,6 @@
//! Physical query planner
-use std::sync::Arc;
-
use super::{
aggregates, cross_join::CrossJoinExec, empty::EmptyExec, expressions::binary,
functions, hash_join::PartitionMode, udaf, union::UnionExec, windows,
@@ -56,6 +54,7 @@ use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use expressions::col;
use log::debug;
+use std::sync::Arc;
/// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`].
pub trait ExtensionPlanner {
diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs
new file mode 100644
index 0000000..f0be5a2
--- /dev/null
+++ b/datafusion/src/physical_plan/window_frames.rs
@@ -0,0 +1,337 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Window frame
+//!
+//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts:
+//! - A frame type - either ROWS, RANGE or GROUPS,
+//! - A starting frame boundary,
+//! - An ending frame boundary,
+//! - An EXCLUDE clause (not modeled by this implementation).
+
+use crate::error::{DataFusionError, Result};
+use sqlparser::ast;
+use std::cmp::Ordering;
+use std::convert::{From, TryFrom};
+use std::fmt;
+
+/// The frame-spec determines which output rows are read by an aggregate window function.
+///
+/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the
+/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to
+/// CURRENT ROW.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct WindowFrame {
+ /// A frame type - either ROWS, RANGE or GROUPS
+ pub units: WindowFrameUnits,
+ /// A starting frame boundary
+ pub start_bound: WindowFrameBound,
+ /// An ending frame boundary
+ pub end_bound: WindowFrameBound,
+}
+
+impl fmt::Display for WindowFrame {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{} BETWEEN {} AND {}",
+ self.units, self.start_bound, self.end_bound
+ )?;
+ Ok(())
+ }
+}
+
+impl TryFrom<ast::WindowFrame> for WindowFrame {
+ type Error = DataFusionError;
+
+ fn try_from(value: ast::WindowFrame) -> Result<Self> {
+ let start_bound = value.start_bound.into();
+ let end_bound = value
+ .end_bound
+ .map(WindowFrameBound::from)
+ .unwrap_or(WindowFrameBound::CurrentRow);
+
+ if let WindowFrameBound::Following(None) = start_bound {
+ Err(DataFusionError::Execution(
+ "Invalid window frame: start bound cannot be unbounded following"
+ .to_owned(),
+ ))
+ } else if let WindowFrameBound::Preceding(None) = end_bound {
+ Err(DataFusionError::Execution(
+ "Invalid window frame: end bound cannot be unbounded preceding"
+ .to_owned(),
+ ))
+ } else if start_bound > end_bound {
+ Err(DataFusionError::Execution(format!(
+ "Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
+ start_bound, end_bound
+ )))
+ } else {
+ let units = value.units.into();
+ Ok(Self {
+ units,
+ start_bound,
+ end_bound,
+ })
+ }
+ }
+}
+
+impl Default for WindowFrame {
+ fn default() -> Self {
+ WindowFrame {
+ units: WindowFrameUnits::Range,
+ start_bound: WindowFrameBound::Preceding(None),
+ end_bound: WindowFrameBound::CurrentRow,
+ }
+ }
+}
+
+/// There are five ways to describe starting and ending frame boundaries:
+///
+/// 1. UNBOUNDED PRECEDING
+/// 2. <expr> PRECEDING
+/// 3. CURRENT ROW
+/// 4. <expr> FOLLOWING
+/// 5. UNBOUNDED FOLLOWING
+///
+/// In this implementation <expr> must be a u64 constant (i.e. no dynamic boundary).
+#[derive(Debug, Clone, Copy, Eq)]
+pub enum WindowFrameBound {
+ /// 1. UNBOUNDED PRECEDING
+ /// The frame boundary is the first row in the partition.
+ ///
+ /// 2. <expr> PRECEDING
+ /// <expr> must be a non-negative constant numeric expression. The boundary is a row that
+ /// is <expr> "units" prior to the current row.
+ Preceding(Option<u64>),
+ /// 3. The current row.
+ ///
+ /// For RANGE and GROUPS frame types, peers of the current row are also
+ /// included in the frame, unless specifically excluded by the EXCLUDE clause.
+ /// This is true regardless of whether CURRENT ROW is used as the starting or ending frame
+ /// boundary.
+ CurrentRow,
+ /// 4. This is the same as "<expr> PRECEDING" except that the boundary is <expr> units after the
+ /// current row rather than before it.
+ ///
+ /// 5. UNBOUNDED FOLLOWING
+ /// The frame boundary is the last row in the partition.
+ Following(Option<u64>),
+}
+
+impl From<ast::WindowFrameBound> for WindowFrameBound {
+ fn from(value: ast::WindowFrameBound) -> Self {
+ match value {
+ ast::WindowFrameBound::Preceding(v) => Self::Preceding(v),
+ ast::WindowFrameBound::Following(v) => Self::Following(v),
+ ast::WindowFrameBound::CurrentRow => Self::CurrentRow,
+ }
+ }
+}
+
+impl fmt::Display for WindowFrameBound {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ WindowFrameBound::CurrentRow => f.write_str("CURRENT ROW"),
+ WindowFrameBound::Preceding(None) => f.write_str("UNBOUNDED PRECEDING"),
+ WindowFrameBound::Following(None) => f.write_str("UNBOUNDED FOLLOWING"),
+ WindowFrameBound::Preceding(Some(n)) => write!(f, "{} PRECEDING", n),
+ WindowFrameBound::Following(Some(n)) => write!(f, "{} FOLLOWING", n),
+ }
+ }
+}
+
+impl PartialEq for WindowFrameBound {
+ fn eq(&self, other: &Self) -> bool {
+ self.cmp(other) == Ordering::Equal
+ }
+}
+
+impl PartialOrd for WindowFrameBound {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl Ord for WindowFrameBound {
+ fn cmp(&self, other: &Self) -> Ordering {
+ self.get_rank().cmp(&other.get_rank())
+ }
+}
+
+impl WindowFrameBound {
+ /// Get the rank of this window frame bound.
+ ///
+ /// The rank is a (u8, u64) tuple: the kind is compared first, then the value. The value needs
+ /// special handling: for PRECEDING, a larger value yields a smaller rank, and both 0 PRECEDING
+ /// and 0 FOLLOWING rank the same as CURRENT ROW.
+ fn get_rank(&self) -> (u8, u64) {
+ match self {
+ WindowFrameBound::Preceding(None) => (0, 0),
+ WindowFrameBound::Following(None) => (4, 0),
+ WindowFrameBound::Preceding(Some(0))
+ | WindowFrameBound::CurrentRow
+ | WindowFrameBound::Following(Some(0)) => (2, 0),
+ WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v),
+ WindowFrameBound::Following(Some(v)) => (3, *v),
+ }
+ }
+}
+
+/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the
+/// starting and ending boundaries of the frame are measured.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum WindowFrameUnits {
+ /// The ROWS frame type means that the starting and ending boundaries for the frame are
+ /// determined by counting individual rows relative to the current row.
+ Rows,
+ /// The RANGE frame type requires that the ORDER BY clause of the window have exactly one
+ /// term. Call that term "X". With the RANGE frame type, the elements of the frame are
+ /// determined by computing the value of expression X for all rows in the partition and framing
+ /// those rows for which the value of X is within a certain range of the value of X for the
+ /// current row.
+ Range,
+ /// The GROUPS frame type means that the starting and ending boundaries are determined
+ /// by counting "groups" relative to the current group. A "group" is a set of rows that all have
+ /// equivalent values for all terms of the window ORDER BY clause.
+ Groups,
+}
+
+impl fmt::Display for WindowFrameUnits {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ f.write_str(match self {
+ WindowFrameUnits::Rows => "ROWS",
+ WindowFrameUnits::Range => "RANGE",
+ WindowFrameUnits::Groups => "GROUPS",
+ })
+ }
+}
+
+impl From<ast::WindowFrameUnits> for WindowFrameUnits {
+ fn from(value: ast::WindowFrameUnits) -> Self {
+ match value {
+ ast::WindowFrameUnits::Range => Self::Range,
+ ast::WindowFrameUnits::Groups => Self::Groups,
+ ast::WindowFrameUnits::Rows => Self::Rows,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_window_frame_creation() -> Result<()> {
+ let window_frame = ast::WindowFrame {
+ units: ast::WindowFrameUnits::Range,
+ start_bound: ast::WindowFrameBound::Following(None),
+ end_bound: None,
+ };
+ let result = WindowFrame::try_from(window_frame);
+ assert_eq!(
+ result.err().unwrap().to_string(),
+ "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned()
+ );
+
+ let window_frame = ast::WindowFrame {
+ units: ast::WindowFrameUnits::Range,
+ start_bound: ast::WindowFrameBound::Preceding(None),
+ end_bound: Some(ast::WindowFrameBound::Preceding(None)),
+ };
+ let result = WindowFrame::try_from(window_frame);
+ assert_eq!(
+ result.err().unwrap().to_string(),
+ "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned()
+ );
+
+ let window_frame = ast::WindowFrame {
+ units: ast::WindowFrameUnits::Range,
+ start_bound: ast::WindowFrameBound::Preceding(Some(1)),
+ end_bound: Some(ast::WindowFrameBound::Preceding(Some(2))),
+ };
+ let result = WindowFrame::try_from(window_frame);
+ assert_eq!(
+ result.err().unwrap().to_string(),
+ "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned()
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn test_eq() {
+ assert_eq!(
+ WindowFrameBound::Preceding(Some(0)),
+ WindowFrameBound::CurrentRow
+ );
+ assert_eq!(
+ WindowFrameBound::CurrentRow,
+ WindowFrameBound::Following(Some(0))
+ );
+ assert_eq!(
+ WindowFrameBound::Following(Some(2)),
+ WindowFrameBound::Following(Some(2))
+ );
+ assert_eq!(
+ WindowFrameBound::Following(None),
+ WindowFrameBound::Following(None)
+ );
+ assert_eq!(
+ WindowFrameBound::Preceding(Some(2)),
+ WindowFrameBound::Preceding(Some(2))
+ );
+ assert_eq!(
+ WindowFrameBound::Preceding(None),
+ WindowFrameBound::Preceding(None)
+ );
+ }
+
+ #[test]
+ fn test_ord() {
+ assert!(WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::CurrentRow);
+ // Intentional: a larger PRECEDING offset is further before the current row, so it ranks lower.
+ assert!(
+ WindowFrameBound::Preceding(Some(2)) < WindowFrameBound::Preceding(Some(1))
+ );
+ assert!(
+ WindowFrameBound::Preceding(Some(u64::MAX))
+ < WindowFrameBound::Preceding(Some(u64::MAX - 1))
+ );
+ assert!(
+ WindowFrameBound::Preceding(None)
+ < WindowFrameBound::Preceding(Some(1000000))
+ );
+ assert!(
+ WindowFrameBound::Preceding(None)
+ < WindowFrameBound::Preceding(Some(u64::MAX))
+ );
+ assert!(WindowFrameBound::Preceding(None) < WindowFrameBound::Following(Some(0)));
+ assert!(
+ WindowFrameBound::Preceding(Some(1)) < WindowFrameBound::Following(Some(1))
+ );
+ assert!(WindowFrameBound::CurrentRow < WindowFrameBound::Following(Some(1)));
+ assert!(
+ WindowFrameBound::Following(Some(1)) < WindowFrameBound::Following(Some(2))
+ );
+ assert!(WindowFrameBound::Following(Some(2)) < WindowFrameBound::Following(None));
+ assert!(
+ WindowFrameBound::Following(Some(u64::MAX))
+ < WindowFrameBound::Following(None)
+ );
+ }
+}
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index aa6b5a9..6bf7b77 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -1121,13 +1121,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// then, window function
if let Some(window) = &function.over {
- if window.partition_by.is_empty() && window.window_frame.is_none() {
+ if window.partition_by.is_empty() {
let order_by = window
.order_by
.iter()
.map(|e| self.order_by_to_sort_expr(e))
.into_iter()
.collect::<Result<Vec<_>>>()?;
+ let window_frame = window
+ .window_frame
+ .as_ref()
+ .map(|window_frame| window_frame.clone().try_into())
+ .transpose()?;
let fun = window_functions::WindowFunction::from_str(&name);
if let Ok(window_functions::WindowFunction::AggregateFunction(
aggregate_fun,
@@ -1140,6 +1145,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
args: self
.aggregate_fn_to_expr(&aggregate_fun, function)?,
order_by,
+ window_frame,
});
} else if let Ok(
window_functions::WindowFunction::BuiltInWindowFunction(
@@ -1151,8 +1157,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
window_fun,
),
- args:self.function_args_to_expr(function)?,
- order_by
+ args: self.function_args_to_expr(function)?,
+ order_by,
+ window_frame,
});
}
}
@@ -2806,6 +2813,45 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn over_order_by_with_window_frame_double_end() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING, #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id DESC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn over_order_by_with_window_frame_single_end() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty) RANGE BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id DESC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn over_order_by_with_window_frame_single_end_groups() {
+ let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders";
+ let expected = "\
+ Projection: #order_id, #MAX(qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, #MIN(qty)\
+ \n WindowAggr: windowExpr=[[MAX(#qty) GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]] partitionBy=[]\
+ \n Sort: #order_id ASC NULLS FIRST\
+ \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\
+ \n Sort: #order_id DESC NULLS FIRST\
+ \n TableScan: orders projection=None";
+ quick_test(sql, expected);
+ }
+
/// psql result
/// ```
/// QUERY PLAN
diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs
index 80a25d0..7a5dc0d 100644
--- a/datafusion/src/sql/utils.rs
+++ b/datafusion/src/sql/utils.rs
@@ -239,6 +239,7 @@ where
fun,
args,
order_by,
+ window_frame,
} => Ok(Expr::WindowFunction {
fun: fun.clone(),
args: args
@@ -249,6 +250,7 @@ where
.iter()
.map(|e| clone_with_replacement(e, replacement_fn))
.collect::<Result<Vec<_>>>()?,
+ window_frame: *window_frame,
}),
Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
@@ -453,21 +455,25 @@ mod tests {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
order_by: vec![],
+ window_frame: None,
};
let max2 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
order_by: vec![],
+ window_frame: None,
};
let min3 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
args: vec![col("name")],
order_by: vec![],
+ window_frame: None,
};
let sum4 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
args: vec![col("age")],
order_by: vec![],
+ window_frame: None,
};
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -500,21 +506,25 @@ mod tests {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
order_by: vec![age_asc.clone(), name_desc.clone()],
+ window_frame: None,
};
let max2 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Max),
args: vec![col("name")],
order_by: vec![],
+ window_frame: None,
};
let min3 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Min),
args: vec![col("name")],
order_by: vec![age_asc.clone(), name_desc.clone()],
+ window_frame: None,
};
let sum4 = Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
args: vec![col("age")],
order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
+ window_frame: None,
};
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -551,6 +561,7 @@ mod tests {
nulls_first: true,
},
],
+ window_frame: None,
},
Expr::WindowFunction {
fun: WindowFunction::AggregateFunction(AggregateFunction::Sum),
@@ -572,6 +583,7 @@ mod tests {
nulls_first: true,
},
],
+ window_frame: None,
},
];
let expected = vec![