You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/23 14:48:37 UTC
[arrow-datafusion] branch master updated: Support for bounded execution when window frame involves UNBOUNDED PRECEDING (#5003)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 624f02d15 Support for bounded execution when window frame involves UNBOUNDED PRECEDING (#5003)
624f02d15 is described below
commit 624f02d15787c3dfc4da40db2566c6161ded9bfe
Author: Mustafa Akur <10...@users.noreply.github.com>
AuthorDate: Mon Jan 23 17:48:31 2023 +0300
Support for bounded execution when window frame involves UNBOUNDED PRECEDING (#5003)
* initial support for aggregators
* Move common functionality to super trait for aggregates
* update tests
* bounded first_value, nth_value, last_value support
* nth_value bug fix
* minor changes
* Review and refactor
* Change naming: NonSliding -> Plain
* remove redundant check
* Remove window function state
* minor changes
* Remove unnecessary continue
* Address reviews
* Address reviews
* Address reviews
Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
.../windows/bounded_window_agg_exec.rs | 9 +-
datafusion/core/src/physical_plan/windows/mod.rs | 6 +-
datafusion/core/tests/sql/window.rs | 139 ++++++++++++++--
datafusion/core/tests/window_fuzz.rs | 156 +++++++++++++-----
datafusion/physical-expr/src/aggregate/average.rs | 10 +-
datafusion/physical-expr/src/expressions/mod.rs | 2 +-
datafusion/physical-expr/src/window/aggregate.rs | 143 +++++++++-------
datafusion/physical-expr/src/window/built_in.rs | 99 +++++++----
datafusion/physical-expr/src/window/mod.rs | 2 +-
datafusion/physical-expr/src/window/nth_value.rs | 42 +++--
.../src/window/partition_evaluator.rs | 6 +
.../physical-expr/src/window/sliding_aggregate.rs | 159 +++---------------
datafusion/physical-expr/src/window/window_expr.rs | 181 ++++++++++++++++++---
13 files changed, 629 insertions(+), 325 deletions(-)
diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
index 5ed6a112c..13b6d88da 100644
--- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
@@ -548,13 +548,14 @@ impl SortedPartitionByBoundedWindowStream {
for window_agg_state in self.window_agg_states.iter_mut() {
window_agg_state.retain(|_, WindowState { state, .. }| !state.is_end);
for (partition_row, WindowState { state: value, .. }) in window_agg_state {
+ let n_prune =
+ min(value.window_frame_range.start, value.last_calculated_index);
if let Some(state) = n_prune_each_partition.get_mut(partition_row) {
- if value.window_frame_range.start < *state {
- *state = value.window_frame_range.start;
+ if n_prune < *state {
+ *state = n_prune;
}
} else {
- n_prune_each_partition
- .insert(partition_row.clone(), value.window_frame_range.start);
+ n_prune_each_partition.insert(partition_row.clone(), n_prune);
}
}
}
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 2d7aa0494..bdb9aa326 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -44,7 +44,7 @@ mod window_agg_exec;
pub use bounded_window_agg_exec::BoundedWindowAggExec;
pub use datafusion_physical_expr::window::{
- AggregateWindowExpr, BuiltInWindowExpr, WindowExpr,
+ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr,
};
pub use window_agg_exec::WindowAggExec;
@@ -70,7 +70,7 @@ pub fn create_window_expr(
window_frame,
))
} else {
- Arc::new(AggregateWindowExpr::new(
+ Arc::new(PlainAggregateWindowExpr::new(
aggregate,
partition_by,
order_by,
@@ -84,7 +84,7 @@ pub fn create_window_expr(
order_by,
window_frame,
)),
- WindowFunction::AggregateUDF(fun) => Arc::new(AggregateWindowExpr::new(
+ WindowFunction::AggregateUDF(fun) => Arc::new(PlainAggregateWindowExpr::new(
udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?,
partition_by,
order_by,
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 24d86d527..f0fd04efd 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -1683,8 +1683,8 @@ async fn test_window_agg_sort() -> Result<()> {
vec![
"ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum2]",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c9@1 ASC NULLS LAST,c8@0 ASC NULLS LAST]",
]
};
@@ -1716,8 +1716,8 @@ async fn over_order_by_sort_keys_sorting_prefix_compacting() -> Result<()> {
"ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MI [...]
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]",
- " WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
- " WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]",
]
};
@@ -1751,9 +1751,9 @@ async fn over_order_by_sort_keys_sorting_global_order_compacting() -> Result<()>
" ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9), SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9), MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN U [...]
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)) }]",
- " WindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9): Ok(Field { name: \"MAX(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c9@1 ASC NULLS LAST,c2@0 ASC NULLS LAST]",
- " WindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9): Ok(Field { name: \"MIN(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" SortExec: [c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST]",
]
};
@@ -2083,15 +2083,15 @@ async fn test_window_agg_complex_plan() -> Result<()> {
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" GlobalLimitExec: skip=0, fetch=5",
" WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Prec [...]
- " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
" SortExec: [c3@2 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
- " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
" SortExec: [c3@2 ASC NULLS LAST,c1@0 ASC]",
" WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]",
" WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start [...]
" WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, sta [...]
" WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, s [...]
- " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
" SortExec: [c3@2 DESC,c1@0 ASC NULLS LAST]",
]
};
@@ -2241,7 +2241,7 @@ async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> {
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" GlobalLimitExec: skip=0, fetch=5",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_o [...]
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED [...]
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UN [...]
" SortExec: [CAST(c3@1 AS Int16) + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST]",
]
};
@@ -2353,8 +2353,8 @@ async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Re
"ProjectionExec: expr=[c3@1 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2]",
" RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1",
" GlobalLimitExec: skip=0, fetch=5",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]",
" SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]",
]
};
@@ -2630,4 +2630,119 @@ mod tests {
assert_batches_eq!(expected, &actual);
Ok(())
}
+
+ #[tokio::test]
+ async fn test_source_sorted_unbounded_preceding() -> Result<()> {
+ let tmpdir = TempDir::new().unwrap();
+ let ctx = get_test_context(&tmpdir).await?;
+
+ let sql = "SELECT
+ SUM(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as sum1,
+ SUM(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as sum2,
+ MIN(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as min1,
+ MIN(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as min2,
+ MAX(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as max1,
+ MAX(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as max2,
+ COUNT(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as count1,
+ COUNT(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as count2,
+ AVG(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as avg1,
+ AVG(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as avg2
+ FROM annotated_data
+ ORDER BY inc_col ASC
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " SortExec: [inc_col@10 ASC NULLS LAST]",
+ " ProjectionExec: expr=[SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@7 as sum1, SUM(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as sum2, MIN(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@8 as min1, MIN(annotated_data.inc_col) ORDER BY [annotate [...]
+ " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), fr [...]
+ " BoundedWindowAggExec: wdw=[SUM(annotated_data.inc_col): Ok(Field { name: \"SUM(annotated_data.inc_col)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)) }, MIN(annotated_data.inc_col): Ok(Field { name: \"MIN(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), [...]
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------+------+------+------+------+------+--------+--------+-------------------+-------------------+",
+ "| sum1 | sum2 | min1 | min2 | max1 | max2 | count1 | count2 | avg1 | avg2 |",
+ "+------+------+------+------+------+------+--------+--------+-------------------+-------------------+",
+ "| 16 | 6 | 1 | 1 | 10 | 5 | 3 | 2 | 5.333333333333333 | 3 |",
+ "| 16 | 6 | 1 | 1 | 10 | 5 | 3 | 2 | 5.333333333333333 | 3 |",
+ "| 51 | 16 | 1 | 1 | 20 | 10 | 5 | 3 | 10.2 | 5.333333333333333 |",
+ "| 72 | 72 | 1 | 1 | 21 | 21 | 6 | 6 | 12 | 12 |",
+ "| 72 | 72 | 1 | 1 | 21 | 21 | 6 | 6 | 12 | 12 |",
+ "+------+------+------+------+------+------+--------+--------+-------------------+-------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_source_sorted_unbounded_preceding_builtin() -> Result<()> {
+ let tmpdir = TempDir::new().unwrap();
+ let ctx = get_test_context(&tmpdir).await?;
+
+ let sql = "SELECT
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as first_value1,
+ FIRST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as first_value2,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as last_value1,
+ LAST_VALUE(inc_col) OVER(ORDER BY ts DESC ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as last_value2,
+ NTH_VALUE(inc_col, 2) OVER(ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) as nth_value1
+ FROM annotated_data
+ ORDER BY inc_col ASC
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " SortExec: [inc_col@5 ASC NULLS LAST]",
+ " ProjectionExec: expr=[FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, FIRST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, LAST_VALUE(annotated_data.inc_col) ORDER BY [annotated_data.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, LAS [...]
+ " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_orde [...]
+ " BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data.inc_col): Ok(Field { name: \"FIRST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)) }, LAST_VALUE(annotated_data.inc_col): Ok(Field { name: \"LAST_VALUE(annotated_data.inc_col)\", data_type: Int32, nullable: true, dict_id: 0, dict_is_or [...]
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+--------------+--------------+-------------+-------------+------------+",
+ "| first_value1 | first_value2 | last_value1 | last_value2 | nth_value1 |",
+ "+--------------+--------------+-------------+-------------+------------+",
+ "| 1 | 15 | 5 | 1 | 5 |",
+ "| 1 | 20 | 10 | 1 | 5 |",
+ "| 1 | 21 | 15 | 1 | 5 |",
+ "| 1 | 26 | 20 | 1 | 5 |",
+ "| 1 | 29 | 21 | 1 | 5 |",
+ "+--------------+--------------+-------------+-------------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+ }
}
diff --git a/datafusion/core/tests/window_fuzz.rs b/datafusion/core/tests/window_fuzz.rs
index 471484af2..d73248bfc 100644
--- a/datafusion/core/tests/window_fuzz.rs
+++ b/datafusion/core/tests/window_fuzz.rs
@@ -19,6 +19,7 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::{concat_batches, SortOptions};
+use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use hashbrown::HashMap;
@@ -38,7 +39,7 @@ use datafusion_expr::{
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::{col, lit};
-use datafusion_physical_expr::PhysicalSortExpr;
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
#[cfg(test)]
@@ -51,7 +52,7 @@ mod tests {
let distincts = vec![1, 100];
for distinct in distincts {
let mut handles = Vec::new();
- for i in 1..n {
+ for i in 0..n {
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, distinct, i),
i,
@@ -74,7 +75,7 @@ mod tests {
// since we have sorted pairs (a,b) to not violate per partition sorting
// partition should be field a, order by should be field b
let mut handles = Vec::new();
- for i in 1..n {
+ for i in 0..n {
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, distinct, i),
i,
@@ -90,17 +91,11 @@ mod tests {
}
}
-/// Perform batch and running window same input
-/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
-async fn run_window_test(
- input1: Vec<RecordBatch>,
- random_seed: u64,
- orderby_columns: Vec<&str>,
- partition_by_columns: Vec<&str>,
-) {
- let mut rng = StdRng::seed_from_u64(random_seed);
- let schema = input1[0].schema();
- let mut args = vec![col("x", &schema).unwrap()];
+fn get_random_function(
+ schema: &SchemaRef,
+ rng: &mut StdRng,
+) -> (WindowFunction, Vec<Arc<dyn PhysicalExpr>>, String) {
+ let mut args = vec![col("x", schema).unwrap()];
let mut window_fn_map = HashMap::new();
// Each HashMap value is a tuple: the first element is the WindowFunction, the second holds any additional
// arguments the window function requires. For most window functions the additional arguments are empty
@@ -188,16 +183,44 @@ async fn run_window_test(
),
);
- let session_config = SessionConfig::new().with_batch_size(50);
- let ctx = SessionContext::with_config(session_config);
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
for new_arg in new_args {
args.push(new_arg.clone());
}
- let preceding = rng.gen_range(0..50);
- let following = rng.gen_range(0..50);
+
+ (window_fn.clone(), args, fn_name.to_string())
+}
+
+fn get_random_window_frame(rng: &mut StdRng) -> WindowFrame {
+ struct Utils {
+ val: i32,
+ is_preceding: bool,
+ }
+ let first_bound = Utils {
+ val: rng.gen_range(0..50),
+ is_preceding: rng.gen_range(0..2) == 0,
+ };
+ let second_bound = Utils {
+ val: rng.gen_range(0..50),
+ is_preceding: rng.gen_range(0..2) == 0,
+ };
+ let (start_bound, end_bound) =
+ if first_bound.is_preceding == second_bound.is_preceding {
+ if (first_bound.val > second_bound.val && first_bound.is_preceding)
+ || (first_bound.val < second_bound.val && !first_bound.is_preceding)
+ {
+ (first_bound, second_bound)
+ } else {
+ (second_bound, first_bound)
+ }
+ } else if first_bound.is_preceding {
+ (first_bound, second_bound)
+ } else {
+ (second_bound, first_bound)
+ };
+ // 0 means Range, 1 means Rows, 2 means GROUPS
let rand_num = rng.gen_range(0..3);
let units = if rand_num < 1 {
WindowFrameUnits::Range
@@ -208,26 +231,83 @@ async fn run_window_test(
// TODO: once GROUPS handling is available, use WindowFrameUnits::GROUPS in randomized tests also.
WindowFrameUnits::Range
};
- let window_frame = match units {
+ match units {
// In range queries window frame boundaries should match column type
- WindowFrameUnits::Range => WindowFrame {
- units,
- start_bound: WindowFrameBound::Preceding(ScalarValue::Int32(Some(preceding))),
- end_bound: WindowFrameBound::Following(ScalarValue::Int32(Some(following))),
- },
+ WindowFrameUnits::Range => {
+ let start_bound = if start_bound.is_preceding {
+ WindowFrameBound::Preceding(ScalarValue::Int32(Some(start_bound.val)))
+ } else {
+ WindowFrameBound::Following(ScalarValue::Int32(Some(start_bound.val)))
+ };
+ let end_bound = if end_bound.is_preceding {
+ WindowFrameBound::Preceding(ScalarValue::Int32(Some(end_bound.val)))
+ } else {
+ WindowFrameBound::Following(ScalarValue::Int32(Some(end_bound.val)))
+ };
+ let mut window_frame = WindowFrame {
+ units,
+ start_bound,
+ end_bound,
+ };
+ // With 10% probability, use UNBOUNDED PRECEDING as the start bound in tests
+ if rng.gen_range(0..10) == 0 {
+ window_frame.start_bound =
+ WindowFrameBound::Preceding(ScalarValue::Int32(None));
+ }
+ window_frame
+ }
// In window queries, window frame boundary should be Uint64
- WindowFrameUnits::Rows => WindowFrame {
- units,
- start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
- preceding as u64,
- ))),
- end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(
- following as u64,
- ))),
- },
+ WindowFrameUnits::Rows => {
+ let start_bound = if start_bound.is_preceding {
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
+ start_bound.val as u64,
+ )))
+ } else {
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(
+ start_bound.val as u64,
+ )))
+ };
+ let end_bound = if end_bound.is_preceding {
+ WindowFrameBound::Preceding(ScalarValue::UInt64(Some(
+ end_bound.val as u64,
+ )))
+ } else {
+ WindowFrameBound::Following(ScalarValue::UInt64(Some(
+ end_bound.val as u64,
+ )))
+ };
+ let mut window_frame = WindowFrame {
+ units,
+ start_bound,
+ end_bound,
+ };
+ // With 10% probability, use UNBOUNDED PRECEDING as the start bound in tests
+ if rng.gen_range(0..10) == 0 {
+ window_frame.start_bound =
+ WindowFrameBound::Preceding(ScalarValue::UInt64(None));
+ }
+ window_frame
+ }
// Once GROUPS support is added construct window frame for this case also
_ => todo!(),
- };
+ }
+}
+
+/// Runs the batch and the running (bounded) window operators on the same input
+/// and verifies that the outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal.
+async fn run_window_test(
+ input1: Vec<RecordBatch>,
+ random_seed: u64,
+ orderby_columns: Vec<&str>,
+ partition_by_columns: Vec<&str>,
+) {
+ let mut rng = StdRng::seed_from_u64(random_seed);
+ let schema = input1[0].schema();
+ let session_config = SessionConfig::new().with_batch_size(50);
+ let ctx = SessionContext::with_config(session_config);
+ let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng);
+
+ let window_frame = get_random_window_frame(&mut rng);
let mut orderby_exprs = vec![];
for column in orderby_columns {
orderby_exprs.push(PhysicalSortExpr {
@@ -257,8 +337,8 @@ async fn run_window_test(
let usual_window_exec = Arc::new(
WindowAggExec::try_new(
vec![create_window_expr(
- window_fn,
- fn_name.to_string(),
+ &window_fn,
+ fn_name.clone(),
&args,
&partitionby_exprs,
&orderby_exprs,
@@ -278,8 +358,8 @@ async fn run_window_test(
let running_window_exec = Arc::new(
BoundedWindowAggExec::try_new(
vec![create_window_expr(
- window_fn,
- fn_name.to_string(),
+ &window_fn,
+ fn_name,
&args,
&partitionby_exprs,
&orderby_exprs,
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index e0a43aaa3..d755a4405 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -40,7 +40,7 @@ use datafusion_expr::Accumulator;
use datafusion_row::accessor::RowAccessor;
/// AVG aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
@@ -111,6 +111,10 @@ impl AggregateExpr for Avg {
is_row_accumulator_support_dtype(&self.data_type)
}
+ fn supports_bounded_execution(&self) -> bool {
+ true
+ }
+
fn create_row_accumulator(
&self,
start_index: usize,
@@ -121,6 +125,10 @@ impl AggregateExpr for Avg {
)))
}
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ Some(Arc::new(self.clone()))
+ }
+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 411482cc4..63fb7b7d3 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -66,7 +66,7 @@ pub use crate::window::cume_dist::cume_dist;
pub use crate::window::cume_dist::CumeDist;
pub use crate::window::lead_lag::WindowShift;
pub use crate::window::lead_lag::{lag, lead};
-pub use crate::window::nth_value::{NthValue, NthValueKind};
+pub use crate::window::nth_value::NthValue;
pub use crate::window::ntile::Ntile;
pub use crate::window::rank::{dense_rank, percent_rank, rank};
pub use crate::window::rank::{Rank, RankType};
diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs
index fe725f2d7..e9b43cd07 100644
--- a/datafusion/physical-expr/src/window/aggregate.rs
+++ b/datafusion/physical-expr/src/window/aggregate.rs
@@ -18,37 +18,39 @@
//! Physical exec for aggregate window function expressions.
use std::any::Any;
-use std::iter::IntoIterator;
use std::ops::Range;
use std::sync::Arc;
use arrow::array::Array;
-use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::Result;
use datafusion_common::ScalarValue;
-use datafusion_expr::{WindowFrame, WindowFrameUnits};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits};
-use crate::window::window_expr::reverse_order_bys;
-use crate::window::SlidingAggregateWindowExpr;
-use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use crate::{window::WindowExpr, AggregateExpr};
-
-use super::window_frame_state::WindowFrameContext;
+use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr};
+use crate::window::{
+ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr,
+};
+use crate::{expressions::PhysicalSortExpr, AggregateExpr, PhysicalExpr};
/// A window expr that takes the form of an aggregate function
+/// Aggregate window expressions that have the form
+/// `OVER({ROWS | RANGE | GROUPS} BETWEEN UNBOUNDED PRECEDING AND ...)`
+/// (i.e. cumulative window frames) use `PlainAggregateWindowExpr`, whereas aggregate window expressions
+/// that have the form `OVER({ROWS | RANGE | GROUPS} BETWEEN M {PRECEDING | FOLLOWING} AND ...)`
+/// (i.e. sliding window frames) use `SlidingAggregateWindowExpr`.
#[derive(Debug)]
-pub struct AggregateWindowExpr {
+pub struct PlainAggregateWindowExpr {
aggregate: Arc<dyn AggregateExpr>,
partition_by: Vec<Arc<dyn PhysicalExpr>>,
order_by: Vec<PhysicalSortExpr>,
window_frame: Arc<WindowFrame>,
}
-impl AggregateWindowExpr {
- /// create a new aggregate window function expression
+impl PlainAggregateWindowExpr {
+ /// Create a new aggregate window function expression
pub fn new(
aggregate: Arc<dyn AggregateExpr>,
partition_by: &[Arc<dyn PhysicalExpr>],
@@ -72,8 +74,7 @@ impl AggregateWindowExpr {
/// peer based evaluation based on the fact that batch is pre-sorted given the sort columns
/// and then per partition point we'll evaluate the peer group (e.g. SUM or MAX gives the same
/// results for peers) and concatenate the results.
-
-impl WindowExpr for AggregateWindowExpr {
+impl WindowExpr for PlainAggregateWindowExpr {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
@@ -92,47 +93,36 @@ impl WindowExpr for AggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let sort_options: Vec<SortOptions> =
- self.order_by.iter().map(|o| o.options).collect();
- let mut row_wise_results: Vec<ScalarValue> = vec![];
-
- let mut accumulator = self.aggregate.create_accumulator()?;
- let length = batch.num_rows();
- let (values, order_bys) = self.get_values_orderbys(batch)?;
-
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- let mut last_range = Range { start: 0, end: 0 };
-
- // We iterate on each row to perform a running calculation.
- // First, cur_range is calculated, then it is compared with last_range.
- for i in 0..length {
- let cur_range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- length,
- i,
- &last_range,
- )?;
- let value = if cur_range.end == cur_range.start {
- // We produce None if the window is empty.
- ScalarValue::try_from(self.aggregate.field()?.data_type())?
- } else {
- // Accumulate any new rows that have entered the window:
- let update_bound = cur_range.end - last_range.end;
- if update_bound > 0 {
- let update: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.end, update_bound))
- .collect();
- accumulator.update_batch(&update)?
- }
- accumulator.evaluate()?
- };
- row_wise_results.push(value);
- last_range = cur_range;
- }
+ self.aggregate_evaluate(batch)
+ }
- ScalarValue::iter_to_array(row_wise_results.into_iter())
+ fn evaluate_stateful(
+ &self,
+ partition_batches: &PartitionBatches,
+ window_agg_state: &mut PartitionWindowAggStates,
+ ) -> Result<()> {
+ self.aggregate_evaluate_stateful(partition_batches, window_agg_state)?;
+
+ // Update window frame range for each partition. As we know that
+ // non-sliding aggregations will never call `retract_batch`, this value
+ // can safely increase, and we can remove "old" parts of the state.
+ // This enables us to run queries involving UNBOUNDED PRECEDING frames
+ // using bounded memory for suitable aggregations.
+ for partition_row in partition_batches.keys() {
+ let window_state =
+ window_agg_state.get_mut(partition_row).ok_or_else(|| {
+ DataFusionError::Execution("Cannot find state".to_string())
+ })?;
+ let mut state = &mut window_state.state;
+ if self.window_frame.start_bound.is_unbounded() {
+ state.window_frame_range.start = if state.window_frame_range.end >= 1 {
+ state.window_frame_range.end - 1
+ } else {
+ 0
+ };
+ }
+ }
+ Ok(())
}
fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
@@ -151,7 +141,7 @@ impl WindowExpr for AggregateWindowExpr {
self.aggregate.reverse_expr().map(|reverse_expr| {
let reverse_window_frame = self.window_frame.reverse();
if reverse_window_frame.start_bound.is_unbounded() {
- Arc::new(AggregateWindowExpr::new(
+ Arc::new(PlainAggregateWindowExpr::new(
reverse_expr,
&self.partition_by.clone(),
&reverse_order_bys(&self.order_by),
@@ -171,8 +161,47 @@ impl WindowExpr for AggregateWindowExpr {
fn uses_bounded_memory(&self) -> bool {
// NOTE: Currently, groups queries do not support the bounded memory variant.
self.aggregate.supports_bounded_execution()
- && !self.window_frame.start_bound.is_unbounded()
&& !self.window_frame.end_bound.is_unbounded()
&& !matches!(self.window_frame.units, WindowFrameUnits::Groups)
}
}
+
+impl AggregateWindowExpr for PlainAggregateWindowExpr {
+ fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ self.aggregate.create_accumulator()
+ }
+
+ /// For a given range, calculate accumulation result inside the range on
+ /// `value_slice` and update accumulator state.
+ // We assume that `cur_range` contains `last_range` and that their start points
+ // are the same. In summary, if `last_range` is `Range{start: a, end: b}` and
+ // `cur_range` is `Range{start: a1, end: b1}`, it is guaranteed that a1=a and b1>=b.
+ fn get_aggregate_result_inside_range(
+ &self,
+ last_range: &Range<usize>,
+ cur_range: &Range<usize>,
+ value_slice: &[ArrayRef],
+ accumulator: &mut Box<dyn Accumulator>,
+ ) -> Result<ScalarValue> {
+ let value = if cur_range.start == cur_range.end {
+ // We produce None if the window is empty.
+ ScalarValue::try_from(self.aggregate.field()?.data_type())?
+ } else {
+ // Accumulate any new rows that have entered the window:
+ let update_bound = cur_range.end - last_range.end;
+ // A non-sliding aggregation only processes new data, it never
+ // deals with expiring data as its starting point is always the
+ // same point (i.e. the beginning of the table/frame). Hence, we
+ // do not call `retract_batch`.
+ if update_bound > 0 {
+ let update: Vec<ArrayRef> = value_slice
+ .iter()
+ .map(|v| v.slice(last_range.end, update_bound))
+ .collect();
+ accumulator.update_batch(&update)?
+ }
+ accumulator.evaluate()?
+ };
+ Ok(value)
+ }
+}
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index b73e2b8de..b53164f66 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -17,25 +17,26 @@
//! Physical exec for built-in window function expressions.
+use std::any::Any;
+use std::ops::Range;
+use std::sync::Arc;
+
use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::window::window_expr::{
- reverse_order_bys, BuiltinWindowState, WindowFn, WindowFunctionState,
+ reverse_order_bys, BuiltinWindowState, NthValueKind, NthValueState, WindowFn,
};
use crate::window::{
PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
};
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
+use arrow::array::{new_empty_array, Array, ArrayRef};
use arrow::compute::{concat, SortOptions};
+use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::{WindowFrame, WindowFrameUnits};
-use std::any::Any;
-use std::ops::Range;
-use std::sync::Arc;
/// A window expr that takes the form of a built in window function
#[derive(Debug)]
@@ -145,12 +146,7 @@ impl WindowExpr for BuiltInWindowExpr {
window_agg_state.insert(
partition_row.clone(),
WindowState {
- state: WindowAggState::new(
- out_type,
- WindowFunctionState::BuiltinWindowState(
- BuiltinWindowState::Default,
- ),
- )?,
+ state: WindowAggState::new(out_type)?,
window_fn: WindowFn::Builtin(evaluator),
},
);
@@ -170,16 +166,17 @@ impl WindowExpr for BuiltInWindowExpr {
self.get_values_orderbys(&partition_batch_state.record_batch)?;
// We iterate on each row to perform a running calculation.
- let num_rows = partition_batch_state.record_batch.num_rows();
- let mut last_range = state.window_frame_range.clone();
+ let record_batch = &partition_batch_state.record_batch;
+ let num_rows = record_batch.num_rows();
let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let sort_partition_points = if evaluator.include_rank() {
- let columns = self.sort_columns(&partition_batch_state.record_batch)?;
+ let columns = self.sort_columns(record_batch)?;
self.evaluate_partition_points(num_rows, &columns)?
} else {
vec![]
};
let mut row_wise_results: Vec<ScalarValue> = vec![];
+ let mut last_range = state.window_frame_range.clone();
for idx in state.last_calculated_index..num_rows {
state.window_frame_range = if self.expr.uses_window_frame() {
window_frame_ctx.calculate_range(
@@ -194,34 +191,33 @@ impl WindowExpr for BuiltInWindowExpr {
}?;
evaluator.update_state(state, &order_bys, &sort_partition_points)?;
- // Exit if range end index is length, need kind of flag to stop
- if state.window_frame_range.end == num_rows
- && !partition_batch_state.is_end
- {
- state.window_frame_range = last_range.clone();
+ let frame_range = &state.window_frame_range;
+ // Exit if the range extends to the end of the data seen so far and the input has not ended:
+ if frame_range.end == num_rows && !state.is_end {
break;
}
- let frame_range = &state.window_frame_range;
- row_wise_results.push(if frame_range.start == frame_range.end {
- // We produce None if the window is empty.
- ScalarValue::try_from(out_type)
- } else {
- evaluator.evaluate_stateful(&values)
- }?);
- last_range = frame_range.clone();
- state.last_calculated_index = idx + 1;
+ row_wise_results.push(evaluator.evaluate_stateful(&values)?);
+ last_range.clone_from(frame_range);
+ state.last_calculated_index += 1;
}
state.window_frame_range = last_range;
let out_col = if row_wise_results.is_empty() {
- ScalarValue::try_from(out_type)?.to_array_of_size(0)
+ new_empty_array(out_type)
} else {
ScalarValue::iter_to_array(row_wise_results.into_iter())?
};
state.out_col = concat(&[&state.out_col, &out_col])?;
state.n_row_result_missing = num_rows - state.last_calculated_index;
- state.window_function_state =
- WindowFunctionState::BuiltinWindowState(evaluator.state()?);
+ if self.window_frame.start_bound.is_unbounded() {
+ let mut evaluator_state = evaluator.state()?;
+ if let BuiltinWindowState::NthValue(nth_value_state) =
+ &mut evaluator_state
+ {
+ memoize_nth_value(state, nth_value_state)?;
+ evaluator.set_state(&evaluator_state)?;
+ }
+ }
}
Ok(())
}
@@ -245,8 +241,41 @@ impl WindowExpr for BuiltInWindowExpr {
// NOTE: Currently, groups queries do not support the bounded memory variant.
self.expr.supports_bounded_execution()
&& (!self.expr.uses_window_frame()
- || !(self.window_frame.start_bound.is_unbounded()
- || self.window_frame.end_bound.is_unbounded()
+ || !(self.window_frame.end_bound.is_unbounded()
|| matches!(self.window_frame.units, WindowFrameUnits::Groups)))
}
}
+
+// When the window frame has a fixed beginning (e.g. UNBOUNDED PRECEDING), we
+// can memoize the result of the FIRST_VALUE, LAST_VALUE and NTH_VALUE functions.
+// Once the result is calculated, it always stays the same. Hence, we do not
+// need to keep past data as we process the entire dataset. This feature
+// enables us to prune rows from the table.
+fn memoize_nth_value(
+ state: &mut WindowAggState,
+ nth_value_state: &mut NthValueState,
+) -> Result<()> {
+ let out = &state.out_col;
+ let size = out.len();
+ let (is_prunable, new_prunable) = match nth_value_state.kind {
+ NthValueKind::First => {
+ let n_range = state.window_frame_range.end - state.window_frame_range.start;
+ (n_range > 0 && size > 0, true)
+ }
+ NthValueKind::Last => (true, false),
+ NthValueKind::Nth(n) => {
+ let n_range = state.window_frame_range.end - state.window_frame_range.start;
+ (n_range >= (n as usize) && size >= (n as usize), true)
+ }
+ };
+ if is_prunable {
+ if nth_value_state.finalized_result.is_none() && new_prunable {
+ let result = ScalarValue::try_from_array(out, size - 1)?;
+ nth_value_state.finalized_result = Some(result);
+ }
+ if state.window_frame_range.end > 0 {
+ state.window_frame_range.start = state.window_frame_range.end - 1;
+ }
+ }
+ Ok(())
+}
diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs
index 35036a6db..4c8b8b5a4 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -29,7 +29,7 @@ mod sliding_aggregate;
mod window_expr;
mod window_frame_state;
-pub use aggregate::AggregateWindowExpr;
+pub use aggregate::PlainAggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
pub use sliding_aggregate::SlidingAggregateWindowExpr;
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index c40a4fa7d..a1f03f13b 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -19,7 +19,7 @@
//! that can evaluated at runtime during query execution
use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::window_expr::{BuiltinWindowState, NthValueState};
+use crate::window::window_expr::{BuiltinWindowState, NthValueKind, NthValueState};
use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
use crate::PhysicalExpr;
use arrow::array::{Array, ArrayRef};
@@ -30,14 +30,6 @@ use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
-/// nth_value kind
-#[derive(Debug, Copy, Clone)]
-pub enum NthValueKind {
- First,
- Last,
- Nth(u32),
-}
-
/// nth_value expression
#[derive(Debug)]
pub struct NthValue {
@@ -122,10 +114,12 @@ impl BuiltInWindowFunctionExpr for NthValue {
}
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
- Ok(Box::new(NthValueEvaluator {
- state: NthValueState::default(),
+ let state = NthValueState {
+ range: Default::default(),
+ finalized_result: None,
kind: self.kind,
- }))
+ };
+ Ok(Box::new(NthValueEvaluator { state }))
}
fn supports_bounded_execution(&self) -> bool {
@@ -155,7 +149,6 @@ impl BuiltInWindowFunctionExpr for NthValue {
#[derive(Debug)]
pub(crate) struct NthValueEvaluator {
state: NthValueState,
- kind: NthValueKind,
}
impl PartitionEvaluator for NthValueEvaluator {
@@ -171,12 +164,23 @@ impl PartitionEvaluator for NthValueEvaluator {
_sort_partition_points: &[Range<usize>],
) -> Result<()> {
// If we do not use state, update_state does nothing
- self.state.range = state.window_frame_range.clone();
+ self.state.range.clone_from(&state.window_frame_range);
+ Ok(())
+ }
+
+ fn set_state(&mut self, state: &BuiltinWindowState) -> Result<()> {
+ if let BuiltinWindowState::NthValue(nth_value_state) = state {
+ self.state = nth_value_state.clone()
+ }
Ok(())
}
fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result<ScalarValue> {
- self.evaluate_inside_range(values, &self.state.range)
+ if let Some(ref result) = self.state.finalized_result {
+ Ok(result.clone())
+ } else {
+ self.evaluate_inside_range(values, &self.state.range)
+ }
}
fn evaluate_inside_range(
@@ -184,10 +188,14 @@ impl PartitionEvaluator for NthValueEvaluator {
values: &[ArrayRef],
range: &Range<usize>,
) -> Result<ScalarValue> {
- // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take single column, values will have size 1
+ // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take a single column, values will have size 1.
let arr = &values[0];
let n_range = range.end - range.start;
- match self.kind {
+ if n_range == 0 {
+ // We produce None if the window is empty.
+ return ScalarValue::try_from(arr.data_type());
+ }
+ match self.state.kind {
NthValueKind::First => ScalarValue::try_from_array(arr, range.start),
NthValueKind::Last => ScalarValue::try_from_array(arr, range.end - 1),
NthValueKind::Nth(n) => {
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs
index 44fbb2d94..7887d1412 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -48,6 +48,12 @@ pub trait PartitionEvaluator: Debug + Send {
Ok(())
}
+ fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> {
+ Err(DataFusionError::NotImplemented(
+ "set_state is not implemented for this window function".to_string(),
+ ))
+ }
+
fn get_range(&self, _state: &WindowAggState, _n_rows: usize) -> Result<Range<usize>> {
Err(DataFusionError::NotImplemented(
"get_range is not implemented for this window function".to_string(),
diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs
index a429f658c..0723f05c5 100644
--- a/datafusion/physical-expr/src/window/sliding_aggregate.rs
+++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs
@@ -18,30 +18,28 @@
//! Physical exec for aggregate window function expressions.
use std::any::Any;
-use std::iter::IntoIterator;
use std::ops::Range;
use std::sync::Arc;
-use arrow::array::Array;
-use arrow::compute::{concat, SortOptions};
+use arrow::array::{Array, ArrayRef};
+use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::ScalarValue;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame, WindowFrameUnits};
-use crate::window::window_expr::{reverse_order_bys, WindowFn, WindowFunctionState};
+use crate::window::window_expr::{reverse_order_bys, AggregateWindowExpr};
use crate::window::{
- AggregateWindowExpr, PartitionBatches, PartitionWindowAggStates, WindowAggState,
- WindowState,
+ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr,
};
-use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use crate::{window::WindowExpr, AggregateExpr};
-
-use super::window_frame_state::WindowFrameContext;
+use crate::{expressions::PhysicalSortExpr, AggregateExpr, PhysicalExpr};
/// A window expr that takes the form of an aggregate function
+/// Aggregate window expressions that have the form
+/// `OVER({ROWS | RANGE | GROUPS} BETWEEN UNBOUNDED PRECEDING AND ...)`
+/// (i.e. cumulative window frames) use `PlainAggregateWindowExpr`, whereas aggregate window expressions
+/// that have the form `OVER({ROWS | RANGE | GROUPS} BETWEEN M {PRECEDING | FOLLOWING} AND ...)`
+/// (i.e. sliding window frames) use `SlidingAggregateWindowExpr`.
#[derive(Debug)]
pub struct SlidingAggregateWindowExpr {
aggregate: Arc<dyn AggregateExpr>,
@@ -51,7 +49,7 @@ pub struct SlidingAggregateWindowExpr {
}
impl SlidingAggregateWindowExpr {
- /// create a new aggregate window function expression
+ /// Create a new (sliding) aggregate window function expression.
pub fn new(
aggregate: Arc<dyn AggregateExpr>,
partition_by: &[Arc<dyn PhysicalExpr>],
@@ -66,7 +64,7 @@ impl SlidingAggregateWindowExpr {
}
}
- /// Get aggregate expr of AggregateWindowExpr
+ /// Get the [AggregateExpr] of this object.
pub fn get_aggregate_expr(&self) -> &Arc<dyn AggregateExpr> {
&self.aggregate
}
@@ -95,19 +93,7 @@ impl WindowExpr for SlidingAggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let mut accumulator = self.aggregate.create_sliding_accumulator()?;
-
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- let mut last_range = Range { start: 0, end: 0 };
- let mut idx = 0;
- self.get_result_column(
- &mut accumulator,
- batch,
- &mut window_frame_ctx,
- &mut last_range,
- &mut idx,
- true,
- )
+ self.aggregate_evaluate(batch)
}
fn evaluate_stateful(
@@ -115,55 +101,7 @@ impl WindowExpr for SlidingAggregateWindowExpr {
partition_batches: &PartitionBatches,
window_agg_state: &mut PartitionWindowAggStates,
) -> Result<()> {
- let field = self.aggregate.field()?;
- let out_type = field.data_type();
- for (partition_row, partition_batch_state) in partition_batches.iter() {
- if !window_agg_state.contains_key(partition_row) {
- let accumulator = self.aggregate.create_sliding_accumulator()?;
- window_agg_state.insert(
- partition_row.clone(),
- WindowState {
- state: WindowAggState::new(
- out_type,
- WindowFunctionState::AggregateState(vec![]),
- )?,
- window_fn: WindowFn::Aggregate(accumulator),
- },
- );
- };
- let window_state =
- window_agg_state.get_mut(partition_row).ok_or_else(|| {
- DataFusionError::Execution("Cannot find state".to_string())
- })?;
- let accumulator = match &mut window_state.window_fn {
- WindowFn::Aggregate(accumulator) => accumulator,
- _ => unreachable!(),
- };
- let mut state = &mut window_state.state;
- state.is_end = partition_batch_state.is_end;
-
- let mut idx = state.last_calculated_index;
- let mut last_range = state.window_frame_range.clone();
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- let out_col = self.get_result_column(
- accumulator,
- &partition_batch_state.record_batch,
- &mut window_frame_ctx,
- &mut last_range,
- &mut idx,
- state.is_end,
- )?;
- state.last_calculated_index = idx;
- state.window_frame_range = last_range.clone();
-
- state.out_col = concat(&[&state.out_col, &out_col])?;
- let num_rows = partition_batch_state.record_batch.num_rows();
- state.n_row_result_missing = num_rows - state.last_calculated_index;
-
- state.window_function_state =
- WindowFunctionState::AggregateState(accumulator.state()?);
- }
- Ok(())
+ self.aggregate_evaluate_stateful(partition_batches, window_agg_state)
}
fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
@@ -182,7 +120,7 @@ impl WindowExpr for SlidingAggregateWindowExpr {
self.aggregate.reverse_expr().map(|reverse_expr| {
let reverse_window_frame = self.window_frame.reverse();
if reverse_window_frame.start_bound.is_unbounded() {
- Arc::new(AggregateWindowExpr::new(
+ Arc::new(PlainAggregateWindowExpr::new(
reverse_expr,
&self.partition_by.clone(),
&reverse_order_bys(&self.order_by),
@@ -202,15 +140,18 @@ impl WindowExpr for SlidingAggregateWindowExpr {
fn uses_bounded_memory(&self) -> bool {
// NOTE: Currently, groups queries do not support the bounded memory variant.
self.aggregate.supports_bounded_execution()
- && !self.window_frame.start_bound.is_unbounded()
&& !self.window_frame.end_bound.is_unbounded()
&& !matches!(self.window_frame.units, WindowFrameUnits::Groups)
}
}
-impl SlidingAggregateWindowExpr {
- /// For given range calculate accumulator result inside range on value_slice and
- /// update accumulator state
+impl AggregateWindowExpr for SlidingAggregateWindowExpr {
+ fn get_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+ self.aggregate.create_sliding_accumulator()
+ }
+
+ /// Given current range and the last range, calculates the accumulator
+ /// result for the range of interest.
fn get_aggregate_result_inside_range(
&self,
last_range: &Range<usize>,
@@ -218,9 +159,9 @@ impl SlidingAggregateWindowExpr {
value_slice: &[ArrayRef],
accumulator: &mut Box<dyn Accumulator>,
) -> Result<ScalarValue> {
- let value = if cur_range.start == cur_range.end {
+ if cur_range.start == cur_range.end {
// We produce None if the window is empty.
- ScalarValue::try_from(self.aggregate.field()?.data_type())?
+ ScalarValue::try_from(self.aggregate.field()?.data_type())
} else {
// Accumulate any new rows that have entered the window:
let update_bound = cur_range.end - last_range.end;
@@ -240,55 +181,7 @@ impl SlidingAggregateWindowExpr {
.collect();
accumulator.retract_batch(&retract)?
}
- accumulator.evaluate()?
- };
- Ok(value)
- }
-
- fn get_result_column(
- &self,
- accumulator: &mut Box<dyn Accumulator>,
- record_batch: &RecordBatch,
- window_frame_ctx: &mut WindowFrameContext,
- last_range: &mut Range<usize>,
- idx: &mut usize,
- is_end: bool,
- ) -> Result<ArrayRef> {
- let (values, order_bys) = self.get_values_orderbys(record_batch)?;
- // We iterate on each row to perform a running calculation.
- let length = values[0].len();
- let sort_options: Vec<SortOptions> =
- self.order_by.iter().map(|o| o.options).collect();
- let mut row_wise_results: Vec<ScalarValue> = vec![];
- let field = self.aggregate.field()?;
- let out_type = field.data_type();
- while *idx < length {
- let cur_range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- length,
- *idx,
- last_range,
- )?;
- // Exit if range end index is length, need kind of flag to stop
- if cur_range.end == length && !is_end {
- break;
- }
- let value = self.get_aggregate_result_inside_range(
- last_range,
- &cur_range,
- &values,
- accumulator,
- )?;
- row_wise_results.push(value);
- last_range.start = cur_range.start;
- last_range.end = cur_range.end;
- *idx += 1;
+ accumulator.evaluate()
}
- Ok(if row_wise_results.is_empty() {
- ScalarValue::try_from(out_type)?.to_array_of_size(0)
- } else {
- ScalarValue::iter_to_array(row_wise_results.into_iter())?
- })
}
}
diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs
index 656b6723b..a38d5de54 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -16,11 +16,14 @@
// under the License.
use crate::window::partition_evaluator::PartitionEvaluator;
+use crate::window::window_frame_state::WindowFrameContext;
use crate::{PhysicalExpr, PhysicalSortExpr};
+use arrow::array::{new_empty_array, ArrayRef};
use arrow::compute::kernels::partition::lexicographical_partition_ranges;
use arrow::compute::kernels::sort::SortColumn;
+use arrow::compute::{concat, SortOptions};
+use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, datatypes::Field};
use arrow_schema::DataType;
use datafusion_common::{reverse_sort_options, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, WindowFrame};
@@ -64,7 +67,8 @@ pub trait WindowExpr: Send + Sync + Debug {
/// evaluate the window function values against the batch
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
- /// evaluate the window function values against the batch
+ /// Evaluate the window function against the batch. This function facilitates
+ /// stateful, bounded-memory implementations.
fn evaluate_stateful(
&self,
_partition_batches: &PartitionBatches,
@@ -139,6 +143,136 @@ pub trait WindowExpr: Send + Sync + Debug {
fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
}
+/// Common functionality shared by aggregate window expressions (`PlainAggregateWindowExpr`, `SlidingAggregateWindowExpr`).
+pub trait AggregateWindowExpr: WindowExpr {
+ /// Get the accumulator for the window expression. Note that distinct
+ /// window expressions may return distinct accumulators; e.g. sliding
+ /// (non-sliding) expressions will return sliding (normal) accumulators.
+ fn get_accumulator(&self) -> Result<Box<dyn Accumulator>>;
+
+ /// Given current range and the last range, calculates the accumulator
+ /// result for the range of interest.
+ fn get_aggregate_result_inside_range(
+ &self,
+ last_range: &Range<usize>,
+ cur_range: &Range<usize>,
+ value_slice: &[ArrayRef],
+ accumulator: &mut Box<dyn Accumulator>,
+ ) -> Result<ScalarValue>;
+
+ /// Evaluates the window function against the batch.
+ fn aggregate_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
+ let mut accumulator = self.get_accumulator()?;
+ let mut last_range = Range { start: 0, end: 0 };
+ let mut idx = 0;
+ self.get_result_column(
+ &mut accumulator,
+ batch,
+ &mut window_frame_ctx,
+ &mut last_range,
+ &mut idx,
+ false,
+ )
+ }
+
+ /// Statefully evaluates the window function against the batch. Maintains
+ /// state so that it can work incrementally over multiple chunks.
+ fn aggregate_evaluate_stateful(
+ &self,
+ partition_batches: &PartitionBatches,
+ window_agg_state: &mut PartitionWindowAggStates,
+ ) -> Result<()> {
+ let field = self.field()?;
+ let out_type = field.data_type();
+ for (partition_row, partition_batch_state) in partition_batches.iter() {
+ if !window_agg_state.contains_key(partition_row) {
+ let accumulator = self.get_accumulator()?;
+ window_agg_state.insert(
+ partition_row.clone(),
+ WindowState {
+ state: WindowAggState::new(out_type)?,
+ window_fn: WindowFn::Aggregate(accumulator),
+ },
+ );
+ };
+ let window_state =
+ window_agg_state.get_mut(partition_row).ok_or_else(|| {
+ DataFusionError::Execution("Cannot find state".to_string())
+ })?;
+ let accumulator = match &mut window_state.window_fn {
+ WindowFn::Aggregate(accumulator) => accumulator,
+ _ => unreachable!(),
+ };
+ let mut state = &mut window_state.state;
+
+ let record_batch = &partition_batch_state.record_batch;
+ let mut window_frame_ctx = WindowFrameContext::new(self.get_window_frame());
+ let out_col = self.get_result_column(
+ accumulator,
+ record_batch,
+ &mut window_frame_ctx,
+ &mut state.window_frame_range,
+ &mut state.last_calculated_index,
+ !partition_batch_state.is_end,
+ )?;
+ state.is_end = partition_batch_state.is_end;
+ state.out_col = concat(&[&state.out_col, &out_col])?;
+ state.n_row_result_missing =
+ record_batch.num_rows() - state.last_calculated_index;
+ }
+ Ok(())
+ }
+
+ /// Calculates the window expression result for the given record batch.
+ /// Assumes that `record_batch` belongs to a single partition.
+ fn get_result_column(
+ &self,
+ accumulator: &mut Box<dyn Accumulator>,
+ record_batch: &RecordBatch,
+ window_frame_ctx: &mut WindowFrameContext,
+ last_range: &mut Range<usize>,
+ idx: &mut usize,
+ not_end: bool,
+ ) -> Result<ArrayRef> {
+ let (values, order_bys) = self.get_values_orderbys(record_batch)?;
+ // We iterate on each row to perform a running calculation.
+ let length = values[0].len();
+ let sort_options: Vec<SortOptions> =
+ self.order_by().iter().map(|o| o.options).collect();
+ let mut row_wise_results: Vec<ScalarValue> = vec![];
+ while *idx < length {
+ let cur_range = window_frame_ctx.calculate_range(
+ &order_bys,
+ &sort_options,
+ length,
+ *idx,
+ last_range,
+ )?;
+ // Exit if the frame reaches the end of the rows received so far while more input may still arrive:
+ if cur_range.end == length && not_end {
+ break;
+ }
+ let value = self.get_aggregate_result_inside_range(
+ last_range,
+ &cur_range,
+ &values,
+ accumulator,
+ )?;
+ last_range.clone_from(&cur_range);
+ row_wise_results.push(value);
+ *idx += 1;
+ }
+ if row_wise_results.is_empty() {
+ let field = self.field()?;
+ let out_type = field.data_type();
+ Ok(new_empty_array(out_type))
+ } else {
+ ScalarValue::iter_to_array(row_wise_results.into_iter())
+ }
+ }
+}
+
/// Reverses the ORDER BY expression, which is useful during equivalent window
/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into
/// 'ORDER BY a DESC, NULLS FIRST'.
@@ -158,8 +292,7 @@ pub enum WindowFn {
Aggregate(Box<dyn Accumulator>),
}
-/// State for RANK(percent_rank, rank, dense_rank)
-/// builtin window function
+/// State for the RANK family (rank, dense_rank, percent_rank) of built-in window functions.
#[derive(Debug, Clone, Default)]
pub struct RankState {
/// The last values for rank as these values change, we increase n_rank
@@ -170,15 +303,33 @@ pub struct RankState {
pub n_rank: usize,
}
-/// State for 'ROW_NUMBER' builtin window function
+/// State for the 'ROW_NUMBER' built-in window function.
#[derive(Debug, Clone, Default)]
pub struct NumRowsState {
pub n_rows: usize,
}
-#[derive(Debug, Clone, Default)]
+/// Tag to differentiate special use cases of the NTH_VALUE built-in window function.
+#[derive(Debug, Copy, Clone)]
+pub enum NthValueKind {
+ First,
+ Last,
+ Nth(u32),
+}
+
+#[derive(Debug, Clone)]
pub struct NthValueState {
pub range: Range<usize>,
+ // In certain cases, we can finalize the result early. Consider this usage:
+ // ```
+ // FIRST_VALUE(increasing_col) OVER window AS my_first_value
+ // WINDOW window AS (ORDER BY ts ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING)
+ // ```
+ // The result will always be the first entry in the table. We can store such
+ // early-finalizing results and then just reuse them as necessary. This opens
+ // opportunities to prune our datasets.
+ pub finalized_result: Option<ScalarValue>,
+ pub kind: NthValueKind,
}
#[derive(Debug, Clone, Default)]
@@ -195,15 +346,6 @@ pub enum BuiltinWindowState {
#[default]
Default,
}
-#[derive(Debug)]
-pub enum WindowFunctionState {
- /// Different Aggregate functions may have different state definitions
- /// In [Accumulator] trait, [fn state(&self) -> Result<Vec<ScalarValue>>] implementation
- /// dictates that.
- AggregateState(Vec<ScalarValue>),
- /// BuiltinWindowState
- BuiltinWindowState(BuiltinWindowState),
-}
#[derive(Debug)]
pub struct WindowAggState {
@@ -213,9 +355,6 @@ pub struct WindowAggState {
pub last_calculated_index: usize,
/// The offset of the deleted row number
pub offset_pruned_rows: usize,
- /// State of the window function, required to calculate its result
- // For instance, for ROW_NUMBER we keep the row index counter to generate correct result
- pub window_function_state: WindowFunctionState,
/// Stores the results calculated by window frame
pub out_col: ArrayRef,
/// Keeps track of how many rows should be generated to be in sync with input record_batch.
@@ -250,16 +389,12 @@ pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
impl WindowAggState {
- pub fn new(
- out_type: &DataType,
- window_function_state: WindowFunctionState,
- ) -> Result<Self> {
+ pub fn new(out_type: &DataType) -> Result<Self> {
let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
Ok(Self {
window_frame_range: Range { start: 0, end: 0 },
last_calculated_index: 0,
offset_pruned_rows: 0,
- window_function_state,
out_col: empty_out_col,
n_row_result_missing: 0,
is_end: false,