You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/05 11:20:17 UTC
[arrow-datafusion] branch master updated: Custom window frame support extended to built-in window functions (#4078)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 238e17922 Custom window frame support extended to built-in window functions (#4078)
238e17922 is described below
commit 238e179224661f681b20b9ae32f59efd5a3b0713
Author: Mustafa akur <10...@users.noreply.github.com>
AuthorDate: Sat Nov 5 14:20:11 2022 +0300
Custom window frame support extended to built-in window functions (#4078)
* refactor running window
* remove unnecessary changes
* implement suggested changes
* Minor refactors to improve readability
* Refactor according to reviews
* minor changes
* Remove unnecessary into/collect calls
* convert evaluate_inside_range result to ScalarValue
* Simplify evaluate function of BuiltInWindowExpr
Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
datafusion/core/src/physical_plan/windows/mod.rs | 1 +
datafusion/core/tests/sql/window.rs | 182 +++++++++
datafusion/physical-expr/src/window/aggregate.rs | 443 ++++-----------------
datafusion/physical-expr/src/window/built_in.rs | 58 ++-
datafusion/physical-expr/src/window/nth_value.rs | 82 ++--
.../src/window/partition_evaluator.rs | 13 +-
datafusion/physical-expr/src/window/window_expr.rs | 211 +++++++++-
7 files changed, 567 insertions(+), 423 deletions(-)
diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs
index 95582b211..bece2a50c 100644
--- a/datafusion/core/src/physical_plan/windows/mod.rs
+++ b/datafusion/core/src/physical_plan/windows/mod.rs
@@ -65,6 +65,7 @@ pub fn create_window_expr(
create_built_in_window_expr(fun, args, input_schema, name)?,
partition_by,
order_by,
+ window_frame,
)),
})
}
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 9333a7e5a..a36d90c2a 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -1276,3 +1276,185 @@ async fn window_frame_creation() -> Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn test_window_row_number_aggregate() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c8,
+ ROW_NUMBER() OVER(ORDER BY c9) AS rn1,
+ ROW_NUMBER() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rn2
+ FROM aggregate_test_100
+ ORDER BY c8
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+-----+-----+",
+ "| c8 | rn1 | rn2 |",
+ "+-----+-----+-----+",
+ "| 102 | 73 | 73 |",
+ "| 299 | 1 | 1 |",
+ "| 363 | 41 | 41 |",
+ "| 417 | 14 | 14 |",
+ "| 794 | 95 | 95 |",
+ "+-----+-----+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_cume_dist() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c8,
+ CUME_DIST() OVER(ORDER BY c9) as cd1,
+ CUME_DIST() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2
+ FROM aggregate_test_100
+ ORDER BY c8
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+------+------+",
+ "| c8 | cd1 | cd2 |",
+ "+-----+------+------+",
+ "| 102 | 0.73 | 0.73 |",
+ "| 299 | 0.01 | 0.01 |",
+ "| 363 | 0.41 | 0.41 |",
+ "| 417 | 0.14 | 0.14 |",
+ "| 794 | 0.95 | 0.95 |",
+ "+-----+------+------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_rank() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ RANK() OVER(ORDER BY c1) AS rank1,
+ RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as rank2,
+ DENSE_RANK() OVER(ORDER BY c1) as dense_rank1,
+ DENSE_RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as dense_rank2,
+ PERCENT_RANK() OVER(ORDER BY c1) as percent_rank1,
+ PERCENT_RANK() OVER(ORDER BY c1 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as percent_rank2
+ FROM aggregate_test_100
+ ORDER BY c9
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+",
+ "| c9 | rank1 | rank2 | dense_rank1 | dense_rank2 | percent_rank1 | percent_rank2 |",
+ "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+",
+ "| 28774375 | 80 | 80 | 5 | 5 | 0.797979797979798 | 0.797979797979798 |",
+ "| 63044568 | 62 | 62 | 4 | 4 | 0.6161616161616161 | 0.6161616161616161 |",
+ "| 141047417 | 1 | 1 | 1 | 1 | 0 | 0 |",
+ "| 141680161 | 41 | 41 | 3 | 3 | 0.40404040404040403 | 0.40404040404040403 |",
+ "| 145294611 | 1 | 1 | 1 | 1 | 0 | 0 |",
+ "+-----------+-------+-------+-------------+-------------+---------------------+---------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_lag_lead() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9) as lag1,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9) as lead1,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2
+ FROM aggregate_test_100
+ ORDER BY c9
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+-----------+-----------+-----------+-----------+",
+ "| c9 | lag1 | lag2 | lead1 | lead2 |",
+ "+-----------+-----------+-----------+-----------+-----------+",
+ "| 28774375 | 10101 | 10101 | 141047417 | 141047417 |",
+ "| 63044568 | 10101 | 10101 | 141680161 | 141680161 |",
+ "| 141047417 | 28774375 | 28774375 | 145294611 | 145294611 |",
+ "| 141680161 | 63044568 | 63044568 | 225513085 | 225513085 |",
+ "| 145294611 | 141047417 | 141047417 | 243203849 | 243203849 |",
+ "+-----------+-----------+-----------+-----------+-----------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_frame_first_value_last_value_aggregate() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+
+ let sql = "SELECT
+ FIRST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING) as first_value1,
+ FIRST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) as first_value2,
+ LAST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING) as last_value1,
+ LAST_VALUE(c4) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING) as last_value2
+ FROM aggregate_test_100
+ ORDER BY c9
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+--------------+--------------+-------------+-------------+",
+ "| first_value1 | first_value2 | last_value1 | last_value2 |",
+ "+--------------+--------------+-------------+-------------+",
+ "| -16110 | -16110 | 3917 | -1114 |",
+ "| -16110 | -16110 | -16974 | 15673 |",
+ "| -16110 | -16110 | -1114 | 13630 |",
+ "| -16110 | 3917 | 15673 | -13217 |",
+ "| -16110 | -16974 | 13630 | 20690 |",
+ "+--------------+--------------+-------------+-------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn test_window_frame_nth_value_aggregate() -> Result<()> {
+ let config = SessionConfig::new();
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+
+ let sql = "SELECT
+ NTH_VALUE(c4, 3) OVER(ORDER BY c9 ASC ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING) as nth_value1,
+ NTH_VALUE(c4, 2) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) as nth_value2
+ FROM aggregate_test_100
+ ORDER BY c9
+ LIMIT 5";
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------------+------------+",
+ "| nth_value1 | nth_value2 |",
+ "+------------+------------+",
+ "| | 3917 |",
+ "| -16974 | 3917 |",
+ "| -16974 | -16974 |",
+ "| -1114 | -1114 |",
+ "| 15673 | 15673 |",
+ "+------------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs
index 80cb4d10c..e6c754387 100644
--- a/datafusion/physical-expr/src/window/aggregate.rs
+++ b/datafusion/physical-expr/src/window/aggregate.rs
@@ -18,21 +18,17 @@
//! Physical exec for aggregate window function expressions.
use std::any::Any;
-use std::cmp::min;
use std::iter::IntoIterator;
-use std::ops::Range;
use std::sync::Arc;
use arrow::array::Array;
-use arrow::compute::{concat, SortOptions};
+use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::bisect::bisect;
use datafusion_common::Result;
-use datafusion_common::{DataFusionError, ScalarValue};
-use datafusion_expr::{Accumulator, WindowFrameBound};
-use datafusion_expr::{WindowFrame, WindowFrameUnits};
+use datafusion_common::ScalarValue;
+use datafusion_expr::WindowFrame;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};
@@ -61,23 +57,6 @@ impl AggregateWindowExpr {
window_frame,
}
}
-
- /// create a new accumulator based on the underlying aggregation function
- fn create_accumulator(&self) -> Result<AggregateWindowAccumulator> {
- let accumulator = self.aggregate.create_accumulator()?;
- let window_frame = self.window_frame.clone();
- let partition_by = self.partition_by().to_vec();
- let order_by = self.order_by.to_vec();
- let field = self.aggregate.field()?;
-
- Ok(AggregateWindowAccumulator {
- accumulator,
- window_frame,
- partition_by,
- order_by,
- field,
- })
- }
}
/// peer based evaluation based on the fact that batch is pre-sorted given the sort columns
@@ -103,368 +82,86 @@ impl WindowExpr for AggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let num_rows = batch.num_rows();
+ let partition_columns = self.partition_columns(batch)?;
let partition_points =
- self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
+ self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
let values = self.evaluate_args(batch)?;
+ let sort_options: Vec<SortOptions> =
+ self.order_by.iter().map(|o| o.options).collect();
let columns = self.sort_columns(batch)?;
- let array_refs: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect();
+ let order_columns: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect();
// Sort values, this will make the same partitions consecutive. Also, within the partition
// range, values will be sorted.
- let results = partition_points
- .iter()
- .map(|partition_range| {
- let mut window_accumulators = self.create_accumulator()?;
- Ok(vec![window_accumulators.scan(
- &values,
- &array_refs,
- partition_range,
- )?])
- })
- .collect::<Result<Vec<Vec<ArrayRef>>>>()?
- .into_iter()
- .flatten()
- .collect::<Vec<ArrayRef>>();
- let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
- concat(&results).map_err(DataFusionError::ArrowError)
- }
-
- fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
- &self.partition_by
- }
-
- fn order_by(&self) -> &[PhysicalSortExpr] {
- &self.order_by
- }
-}
-
-fn calculate_index_of_row<const BISECT_SIDE: bool, const SEARCH_SIDE: bool>(
- range_columns: &[ArrayRef],
- sort_options: &[SortOptions],
- idx: usize,
- delta: Option<&ScalarValue>,
-) -> Result<usize> {
- let current_row_values = range_columns
- .iter()
- .map(|col| ScalarValue::try_from_array(col, idx))
- .collect::<Result<Vec<ScalarValue>>>()?;
- let end_range = if let Some(delta) = delta {
- let is_descending: bool = sort_options
- .first()
- .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))?
- .descending;
-
- current_row_values
- .iter()
- .map(|value| {
- if value.is_null() {
- return Ok(value.clone());
- }
- if SEARCH_SIDE == is_descending {
- // TODO: Handle positive overflows
- value.add(delta)
- } else if value.is_unsigned() && value < delta {
- // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue.
- // If we decide to implement a "default" construction mechanism for ScalarValue,
- // change the following statement to use that.
- value.sub(value)
+ let order_bys = &order_columns[self.partition_by.len()..];
+ let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() {
+ // OVER (ORDER BY a) case
+ // We create an implicit window for ORDER BY.
+ Some(Arc::new(WindowFrame::default()))
+ } else {
+ self.window_frame.clone()
+ };
+ let mut row_wise_results: Vec<ScalarValue> = vec![];
+ for partition_range in &partition_points {
+ let mut accumulator = self.aggregate.create_accumulator()?;
+ let length = partition_range.end - partition_range.start;
+ let slice_order_bys = order_bys
+ .iter()
+ .map(|v| v.slice(partition_range.start, length))
+ .collect::<Vec<_>>();
+ let value_slice = values
+ .iter()
+ .map(|v| v.slice(partition_range.start, length))
+ .collect::<Vec<_>>();
+
+ let mut last_range: (usize, usize) = (0, 0);
+
+ // We iterate on each row to perform a running calculation.
+ // First, cur_range is calculated, then it is compared with last_range.
+ for i in 0..length {
+ let cur_range = self.calculate_range(
+ &window_frame,
+ &slice_order_bys,
+ &sort_options,
+ length,
+ i,
+ )?;
+ let value = if cur_range.0 == cur_range.1 {
+ // Empty frame: produce a NULL and leave the accumulator and last_range untouched.
+ ScalarValue::try_from(self.aggregate.field()?.data_type())?
} else {
- // TODO: Handle negative overflows
- value.sub(delta)
- }
- })
- .collect::<Result<Vec<ScalarValue>>>()?
- } else {
- current_row_values
- };
- // `BISECT_SIDE` true means bisect_left, false means bisect_right
- bisect::<BISECT_SIDE>(range_columns, &end_range, sort_options)
-}
-
-/// We use start and end bounds to calculate current row's starting and ending range.
-/// This function supports different modes, but we currently do not support window calculation for GROUPS inside window frames.
-fn calculate_current_window(
- window_frame: &WindowFrame,
- range_columns: &[ArrayRef],
- sort_options: &[SortOptions],
- length: usize,
- idx: usize,
-) -> Result<(usize, usize)> {
- match window_frame.units {
- WindowFrameUnits::Range => {
- let start = match &window_frame.start_bound {
- WindowFrameBound::Preceding(n) => {
- if n.is_null() {
- // UNBOUNDED PRECEDING
- Ok(0)
- } else {
- calculate_index_of_row::<true, true>(
- range_columns,
- sort_options,
- idx,
- Some(n),
- )
+ // Accumulate any new rows that have entered the window:
+ let update_bound = cur_range.1 - last_range.1;
+ if update_bound > 0 {
+ let update: Vec<ArrayRef> = value_slice
+ .iter()
+ .map(|v| v.slice(last_range.1, update_bound))
+ .collect();
+ accumulator.update_batch(&update)?
}
- }
- WindowFrameBound::CurrentRow => calculate_index_of_row::<true, true>(
- range_columns,
- sort_options,
- idx,
- None,
- ),
- WindowFrameBound::Following(n) => calculate_index_of_row::<true, false>(
- range_columns,
- sort_options,
- idx,
- Some(n),
- ),
- };
- let end = match &window_frame.end_bound {
- WindowFrameBound::Preceding(n) => calculate_index_of_row::<false, true>(
- range_columns,
- sort_options,
- idx,
- Some(n),
- ),
- WindowFrameBound::CurrentRow => calculate_index_of_row::<false, false>(
- range_columns,
- sort_options,
- idx,
- None,
- ),
- WindowFrameBound::Following(n) => {
- if n.is_null() {
- // UNBOUNDED FOLLOWING
- Ok(length)
- } else {
- calculate_index_of_row::<false, false>(
- range_columns,
- sort_options,
- idx,
- Some(n),
- )
+ // Remove rows that have now left the window:
+ let retract_bound = cur_range.0 - last_range.0;
+ if retract_bound > 0 {
+ let retract: Vec<ArrayRef> = value_slice
+ .iter()
+ .map(|v| v.slice(last_range.0, retract_bound))
+ .collect();
+ accumulator.retract_batch(&retract)?
}
- }
- };
- Ok((start?, end?))
- }
- WindowFrameUnits::Rows => {
- let start = match window_frame.start_bound {
- // UNBOUNDED PRECEDING
- WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => Ok(0),
- WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
- if idx >= n as usize {
- Ok(idx - n as usize)
- } else {
- Ok(0)
- }
- }
- WindowFrameBound::Preceding(_) => {
- Err(DataFusionError::Internal("Rows should be Uint".to_string()))
- }
- WindowFrameBound::CurrentRow => Ok(idx),
- // UNBOUNDED FOLLOWING
- WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
- Err(DataFusionError::Internal(format!(
- "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'",
- window_frame
- )))
- }
- WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
- Ok(min(idx + n as usize, length))
- }
- WindowFrameBound::Following(_) => {
- Err(DataFusionError::Internal("Rows should be Uint".to_string()))
- }
- };
- let end = match window_frame.end_bound {
- // UNBOUNDED PRECEDING
- WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
- Err(DataFusionError::Internal(format!(
- "Frame end cannot be UNBOUNDED PRECEDING '{:?}'",
- window_frame
- )))
- }
- WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
- if idx >= n as usize {
- Ok(idx - n as usize + 1)
- } else {
- Ok(0)
- }
- }
- WindowFrameBound::Preceding(_) => {
- Err(DataFusionError::Internal("Rows should be Uint".to_string()))
- }
- WindowFrameBound::CurrentRow => Ok(idx + 1),
- // UNBOUNDED FOLLOWING
- WindowFrameBound::Following(ScalarValue::UInt64(None)) => Ok(length),
- WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
- Ok(min(idx + n as usize + 1, length))
- }
- WindowFrameBound::Following(_) => {
- Err(DataFusionError::Internal("Rows should be Uint".to_string()))
- }
- };
- Ok((start?, end?))
- }
- WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented(
- "Window frame for groups is not implemented".to_string(),
- )),
- }
-}
-
-/// Aggregate window accumulator utilizes the accumulator from aggregation and do a accumulative sum
-/// across evaluation arguments based on peer equivalences. It uses many information to calculate
-/// correct running window.
-#[derive(Debug)]
-struct AggregateWindowAccumulator {
- accumulator: Box<dyn Accumulator>,
- window_frame: Option<Arc<WindowFrame>>,
- partition_by: Vec<Arc<dyn PhysicalExpr>>,
- order_by: Vec<PhysicalSortExpr>,
- field: Field,
-}
-
-impl AggregateWindowAccumulator {
- /// This function calculates the aggregation on all rows in `value_slice`.
- /// Returns an array of size `length`.
- fn calculate_whole_table(
- &mut self,
- value_slice: &[ArrayRef],
- length: usize,
- ) -> Result<ArrayRef> {
- self.accumulator.update_batch(value_slice)?;
- let value = self.accumulator.evaluate()?;
- Ok(value.to_array_of_size(length))
- }
-
- /// This function calculates the running window logic for the rows in `value_range` of `value_slice`.
- /// We maintain the accumulator state via `update_batch` and `retract_batch` functions.
- /// Note that not all aggregators implement `retract_batch` just yet.
- fn calculate_running_window(
- &mut self,
- value_slice: &[ArrayRef],
- order_bys: &[&ArrayRef],
- value_range: &Range<usize>,
- ) -> Result<ArrayRef> {
- // We iterate on each row to perform a running calculation.
- // First, cur_range is calculated, then it is compared with last_range.
- let length = value_range.end - value_range.start;
- let slice_order_columns = order_bys
- .iter()
- .map(|v| v.slice(value_range.start, length))
- .collect::<Vec<_>>();
- let sort_options: Vec<SortOptions> =
- self.order_by.iter().map(|o| o.options).collect();
-
- let updated_zero_offset_value_range = Range {
- start: 0,
- end: length,
- };
- let mut row_wise_results: Vec<ScalarValue> = vec![];
- let mut last_range: (usize, usize) = (
- updated_zero_offset_value_range.start,
- updated_zero_offset_value_range.start,
- );
-
- for i in 0..length {
- let window_frame = self.window_frame.as_ref().ok_or_else(|| {
- DataFusionError::Internal(
- "Window frame cannot be empty to calculate window ranges".to_string(),
- )
- })?;
- let cur_range = calculate_current_window(
- window_frame,
- &slice_order_columns,
- &sort_options,
- length,
- i,
- )?;
-
- if cur_range.0 == cur_range.1 {
- // We produce None if the window is empty.
- row_wise_results.push(ScalarValue::try_from(self.field.data_type())?)
- } else {
- // Accumulate any new rows that have entered the window:
- let update_bound = cur_range.1 - last_range.1;
- if update_bound > 0 {
- let update: Vec<ArrayRef> = value_slice
- .iter()
- .map(|v| v.slice(last_range.1, update_bound))
- .collect();
- self.accumulator.update_batch(&update)?
- }
- // Remove rows that have now left the window:
- let retract_bound = cur_range.0 - last_range.0;
- if retract_bound > 0 {
- let retract: Vec<ArrayRef> = value_slice
- .iter()
- .map(|v| v.slice(last_range.0, retract_bound))
- .collect();
- self.accumulator.retract_batch(&retract)?
- }
- row_wise_results.push(self.accumulator.evaluate()?);
+ accumulator.evaluate()?
+ };
+ row_wise_results.push(value);
+ if cur_range.0 != cur_range.1 { last_range = cur_range; }
}
- last_range = cur_range;
}
ScalarValue::iter_to_array(row_wise_results.into_iter())
}
- fn scan(
- &mut self,
- values: &[ArrayRef],
- order_bys: &[&ArrayRef],
- value_range: &Range<usize>,
- ) -> Result<ArrayRef> {
- if value_range.is_empty() {
- return Err(DataFusionError::Internal(
- "Value range cannot be empty".to_owned(),
- ));
- }
- let length = value_range.end - value_range.start;
- let value_slice = values
- .iter()
- .map(|v| v.slice(value_range.start, length))
- .collect::<Vec<_>>();
- let order_columns = &order_bys[self.partition_by.len()..order_bys.len()].to_vec();
- match (&order_columns[..], &self.window_frame) {
- ([], None) => {
- // OVER () case
- self.calculate_whole_table(&value_slice, length)
- }
- ([column, ..], None) => {
- // OVER (ORDER BY a) case
- // We create an implicit window for ORDER BY.
- let empty_bound = ScalarValue::try_from(column.data_type())?;
- self.window_frame = Some(Arc::new(WindowFrame {
- units: WindowFrameUnits::Range,
- start_bound: WindowFrameBound::Preceding(empty_bound),
- end_bound: WindowFrameBound::CurrentRow,
- }));
- self.calculate_running_window(&value_slice, order_columns, value_range)
- }
- ([], Some(frame)) => {
- match frame.units {
- WindowFrameUnits::Range => {
- // OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) case
- self.calculate_whole_table(&value_slice, length)
- }
- WindowFrameUnits::Rows => {
- // OVER (ROWS BETWEEN X PRECEDING AND Y FOLLOWING) case
- self.calculate_running_window(
- &value_slice,
- order_bys,
- value_range,
- )
- }
- WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented(
- "Window frame for groups is not implemented".to_string(),
- )),
- }
- }
- // OVER (ORDER BY a ROWS/RANGE BETWEEN X PRECEDING AND Y FOLLOWING) case
- _ => self.calculate_running_window(&value_slice, order_columns, value_range),
- }
+ fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
+ &self.partition_by
+ }
+
+ fn order_by(&self) -> &[PhysicalSortExpr] {
+ &self.order_by
}
}
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index 2fa1f808f..e4e377175 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -20,12 +20,15 @@
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use arrow::compute::concat;
+use arrow::array::Array;
+use arrow::compute::{concat, SortOptions};
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::DataFusionError;
use datafusion_common::Result;
+use datafusion_expr::WindowFrame;
use std::any::Any;
+use std::ops::Range;
use std::sync::Arc;
/// A window expr that takes the form of a built in window function
@@ -34,6 +37,7 @@ pub struct BuiltInWindowExpr {
expr: Arc<dyn BuiltInWindowFunctionExpr>,
partition_by: Vec<Arc<dyn PhysicalExpr>>,
order_by: Vec<PhysicalSortExpr>,
+ window_frame: Option<Arc<WindowFrame>>,
}
impl BuiltInWindowExpr {
@@ -42,11 +46,13 @@ impl BuiltInWindowExpr {
expr: Arc<dyn BuiltInWindowFunctionExpr>,
partition_by: &[Arc<dyn PhysicalExpr>],
order_by: &[PhysicalSortExpr],
+ window_frame: Option<Arc<WindowFrame>>,
) -> Self {
Self {
expr,
partition_by: partition_by.to_vec(),
order_by: order_by.to_vec(),
+ window_frame,
}
}
}
@@ -80,11 +86,55 @@ impl WindowExpr for BuiltInWindowExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let evaluator = self.expr.create_evaluator(batch)?;
let num_rows = batch.num_rows();
+ let partition_columns = self.partition_columns(batch)?;
let partition_points =
- self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
- let results = if evaluator.include_rank() {
+ self.evaluate_partition_points(num_rows, &partition_columns)?;
+
+ let results = if evaluator.uses_window_frame() {
+ let sort_options: Vec<SortOptions> =
+ self.order_by.iter().map(|o| o.options).collect();
+ let columns = self.sort_columns(batch)?;
+ let order_columns: Vec<&ArrayRef> =
+ columns.iter().map(|s| &s.values).collect();
+ // The batch is pre-sorted by the sort columns, so each partition is a contiguous row
+ // range; skip the PARTITION BY prefix to keep only the ORDER BY columns.
+ let order_bys = &order_columns[self.partition_by.len()..];
+ let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() {
+ // OVER (ORDER BY a) case
+ // We create an implicit window for ORDER BY.
+ Some(Arc::new(WindowFrame::default()))
+ } else {
+ self.window_frame.clone()
+ };
+ let mut row_wise_results = vec![];
+ for partition_range in &partition_points {
+ let length = partition_range.end - partition_range.start;
+ let slice_order_bys = order_bys
+ .iter()
+ .map(|v| v.slice(partition_range.start, length))
+ .collect::<Vec<_>>();
+ // For each row, compute its frame relative to this partition (bounded by the partition length, not num_rows), then shift it to absolute batch indices and evaluate.
+ for idx in 0..length {
+ let range = self.calculate_range(
+ &window_frame,
+ &slice_order_bys,
+ &sort_options,
+ length,
+ idx,
+ )?;
+ let range = Range {
+ start: partition_range.start + range.0,
+ end: partition_range.start + range.1,
+ };
+ let value = evaluator.evaluate_inside_range(range)?;
+ row_wise_results.push(value.to_array());
+ }
+ }
+ row_wise_results
+ } else if evaluator.include_rank() {
+ let columns = self.sort_columns(batch)?;
let sort_partition_points =
- self.evaluate_partition_points(num_rows, &self.sort_columns(batch)?)?;
+ self.evaluate_partition_points(num_rows, &columns)?;
evaluator.evaluate_with_rank(partition_points, sort_partition_points)?
} else {
evaluator.evaluate(partition_points)?
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index e0a6b2bd7..14ce53621 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -21,14 +21,12 @@
use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
-use arrow::array::{new_null_array, ArrayRef};
-use arrow::compute::kernels::window::shift;
+use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use arrow::record_batch::RecordBatch;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
-use std::iter;
use std::ops::Range;
use std::sync::Arc;
@@ -142,7 +140,7 @@ pub(crate) struct NthValueEvaluator {
}
impl PartitionEvaluator for NthValueEvaluator {
- fn include_rank(&self) -> bool {
+ fn uses_window_frame(&self) -> bool {
true
}
@@ -150,45 +148,19 @@ impl PartitionEvaluator for NthValueEvaluator {
unreachable!("first, last, and nth_value evaluation must be called with evaluate_partition_with_rank")
}
- fn evaluate_partition_with_rank(
- &self,
- partition: Range<usize>,
- ranks_in_partition: &[Range<usize>],
- ) -> Result<ArrayRef> {
+ fn evaluate_inside_range(&self, range: Range<usize>) -> Result<ScalarValue> {
let arr = &self.values[0];
- let num_rows = partition.end - partition.start;
+ let n_range = range.end - range.start;
match self.kind {
- NthValueKind::First => {
- let value = ScalarValue::try_from_array(arr, partition.start)?;
- Ok(value.to_array_of_size(num_rows))
- }
- NthValueKind::Last => {
- // because the default window frame is between unbounded preceding and current
- // row with peer evaluation, hence the last rows expands until the end of the peers
- let values = ranks_in_partition
- .iter()
- .map(|range| {
- let len = range.end - range.start;
- let value = ScalarValue::try_from_array(arr, range.end - 1)?;
- Ok(iter::repeat(value).take(len))
- })
- .collect::<Result<Vec<_>>>()?
- .into_iter()
- .flatten();
- ScalarValue::iter_to_array(values)
- }
+ NthValueKind::First => if n_range == 0 { ScalarValue::try_from(arr.data_type()) } else { ScalarValue::try_from_array(arr, range.start) },
+ NthValueKind::Last => if n_range == 0 { ScalarValue::try_from(arr.data_type()) } else { ScalarValue::try_from_array(arr, range.end - 1) },
NthValueKind::Nth(n) => {
+ // Assumes n > 0 so `index` cannot underflow; frames with fewer than n rows (including empty frames) yield NULL.
let index = (n as usize) - 1;
- if index >= num_rows {
- Ok(new_null_array(arr.data_type(), num_rows))
+ if index >= n_range {
+ ScalarValue::try_from(arr.data_type())
} else {
- let value =
- ScalarValue::try_from_array(arr, partition.start + index)?;
- let arr = value.to_array_of_size(num_rows);
- // because the default window frame is between unbounded preceding and current
- // row, hence the shift because for values with indices < index they should be
- // null. This changes when window frames other than default is implemented
- shift(arr.as_ref(), index as i64).map_err(DataFusionError::ArrowError)
+ ScalarValue::try_from_array(arr, range.start + index)
}
}
}
@@ -208,11 +180,21 @@ mod tests {
let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
- let result = expr
- .create_evaluator(&batch)?
- .evaluate_with_rank(vec![0..8], vec![0..8])?;
- assert_eq!(1, result.len());
- let result = result[0].as_any().downcast_ref::<Int32Array>().unwrap();
+ let mut ranges: Vec<Range<usize>> = vec![];
+ for i in 0..8 {
+ ranges.push(Range {
+ start: 0,
+ end: i + 1,
+ })
+ }
+ let evaluator = expr.create_evaluator(&batch)?;
+ let result = ranges
+ .into_iter()
+ .map(|range| evaluator.evaluate_inside_range(range))
+ .into_iter()
+ .collect::<Result<Vec<ScalarValue>>>()?;
+ let result = ScalarValue::iter_to_array(result.into_iter())?;
+ let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(expected, *result);
Ok(())
}
@@ -235,7 +217,19 @@ mod tests {
Arc::new(Column::new("arr", 0)),
DataType::Int32,
);
- test_i32_result(last_value, Int32Array::from_iter_values(vec![8; 8]))?;
+ test_i32_result(
+ last_value,
+ Int32Array::from(vec![
+ Some(1),
+ Some(-2),
+ Some(3),
+ Some(-4),
+ Some(5),
+ Some(-6),
+ Some(7),
+ Some(8),
+ ]),
+ )?;
Ok(())
}
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs
index c3a88367a..4ecfd87a9 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -18,8 +18,8 @@
//! partition evaluation module
use arrow::array::ArrayRef;
-use datafusion_common::DataFusionError;
use datafusion_common::Result;
+use datafusion_common::{DataFusionError, ScalarValue};
use std::ops::Range;
/// Given a partition range, and the full list of sort partition points, given that the sort
@@ -46,6 +46,10 @@ pub trait PartitionEvaluator {
false
}
+    /// Whether this evaluator uses the window frame; the default implementation returns `false`.
+    fn uses_window_frame(&self) -> bool {
+        false
+    }
+
/// evaluate the partition evaluator against the partitions
fn evaluate(&self, partition_points: Vec<Range<usize>>) -> Result<Vec<ArrayRef>> {
partition_points
@@ -83,4 +87,11 @@ pub trait PartitionEvaluator {
"evaluate_partition_with_rank is not implemented by default".into(),
))
}
+
+    /// Evaluates the window function over the rows in `_range`, producing one `ScalarValue`.
+    /// The default implementation always fails with `DataFusionError::NotImplemented`.
+    fn evaluate_inside_range(&self, _range: Range<usize>) -> Result<ScalarValue> {
+        let msg = "evaluate_inside_range is not implemented by default";
+        Err(DataFusionError::NotImplemented(msg.into()))
+    }
}
diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs
index 67caba51d..9c4b1b179 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -20,12 +20,17 @@ use arrow::compute::kernels::partition::lexicographical_partition_ranges;
use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::bisect::bisect;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use std::any::Any;
+use std::cmp::min;
use std::fmt::Debug;
use std::ops::Range;
use std::sync::Arc;
+use datafusion_expr::WindowFrameBound;
+use datafusion_expr::{WindowFrame, WindowFrameUnits};
+
/// A window expression that:
/// * knows its resulting field
pub trait WindowExpr: Send + Sync + Debug {
@@ -110,4 +115,208 @@ pub trait WindowExpr: Send + Sync + Debug {
sort_columns.extend(order_by_columns);
Ok(sort_columns)
}
+
+    /// Returns the `[start, end)` frame range for row `idx` of a partition with `length`
+    /// rows. `RANGE` offset bounds are found by bisecting `range_columns` (expected to be
+    /// sorted per `sort_options`); `ROWS` offsets are applied to `idx` directly. Without a
+    /// window frame the whole partition is used; `GROUPS` frames return `NotImplemented`.
+    fn calculate_range(
+        &self,
+        window_frame: &Option<Arc<WindowFrame>>,
+        range_columns: &[ArrayRef],
+        sort_options: &[SortOptions],
+        length: usize,
+        idx: usize,
+    ) -> Result<(usize, usize)> {
+        if let Some(window_frame) = window_frame {
+            match window_frame.units {
+                WindowFrameUnits::Range => {
+                    let start = match &window_frame.start_bound {
+                        // n PRECEDING; a NULL n means UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(n) => {
+                            if n.is_null() {
+                                0
+                            } else {
+                                calculate_index_of_row::<true, true>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    Some(n),
+                                )?
+                            }
+                        }
+                        WindowFrameBound::CurrentRow => {
+                            if range_columns.is_empty() {
+                                0
+                            } else {
+                                calculate_index_of_row::<true, true>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    None,
+                                )?
+                            }
+                        }
+                        WindowFrameBound::Following(n) => {
+                            calculate_index_of_row::<true, false>(
+                                range_columns,
+                                sort_options,
+                                idx,
+                                Some(n),
+                            )?
+                        }
+                    };
+                    let end = match &window_frame.end_bound {
+                        WindowFrameBound::Preceding(n) => {
+                            calculate_index_of_row::<false, true>(
+                                range_columns,
+                                sort_options,
+                                idx,
+                                Some(n),
+                            )?
+                        }
+                        WindowFrameBound::CurrentRow => {
+                            if range_columns.is_empty() {
+                                length
+                            } else {
+                                calculate_index_of_row::<false, false>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    None,
+                                )?
+                            }
+                        }
+                        WindowFrameBound::Following(n) => {
+                            if n.is_null() {
+                                // UNBOUNDED FOLLOWING
+                                length
+                            } else {
+                                calculate_index_of_row::<false, false>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    Some(n),
+                                )?
+                            }
+                        }
+                    };
+                    Ok((start, end))
+                }
+                WindowFrameUnits::Rows => {
+                    let start = match window_frame.start_bound {
+                        // UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
+                        // Offsets are applied in `u64` with saturation: huge offsets
+                        // neither overflow nor truncate on 32-bit targets, and results
+                        // stay within `idx + 1` / `length`, so `as usize` is lossless.
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
+                            (idx as u64).saturating_sub(n) as usize
+                        }
+                        WindowFrameBound::Preceding(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                        WindowFrameBound::CurrentRow => idx,
+                        // UNBOUNDED FOLLOWING
+                        WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
+                            return Err(DataFusionError::Internal(format!(
+                                "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'",
+                                window_frame
+                            )))
+                        }
+                        WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
+                            min((idx as u64).saturating_add(n), length as u64) as usize
+                        }
+                        WindowFrameBound::Following(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                    };
+                    let end = match window_frame.end_bound {
+                        // UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
+                            return Err(DataFusionError::Internal(format!(
+                                "Frame end cannot be UNBOUNDED PRECEDING '{:?}'",
+                                window_frame
+                            )))
+                        }
+                        // Exclusive end `idx - n + 1`, clamped at 0.
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
+                            (idx as u64 + 1).saturating_sub(n) as usize
+                        }
+                        WindowFrameBound::Preceding(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                        WindowFrameBound::CurrentRow => idx + 1,
+                        // UNBOUNDED FOLLOWING
+                        WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
+                        // Exclusive end `idx + n + 1`, capped at `length`.
+                        WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
+                            let frame_end = (idx as u64 + 1).saturating_add(n);
+                            min(frame_end, length as u64) as usize
+                        }
+                        WindowFrameBound::Following(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                    };
+                    Ok((start, end))
+                }
+                WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented(
+                    "Window frame for groups is not implemented".to_string(),
+                )),
+            }
+        } else {
+            Ok((0, length))
+        }
+    }
+}
+
+/// Finds the bisection index in `range_columns` for the values of row `idx`, each
+/// shifted by `delta` when given; `SEARCH_SIDE == true` shifts toward preceding rows.
+/// `BISECT_SIDE == true` uses `bisect_left`, `false` uses `bisect_right`.
+fn calculate_index_of_row<const BISECT_SIDE: bool, const SEARCH_SIDE: bool>(
+    range_columns: &[ArrayRef],
+    sort_options: &[SortOptions],
+    idx: usize,
+    delta: Option<&ScalarValue>,
+) -> Result<usize> {
+    let current_row_values = range_columns
+        .iter()
+        .map(|col| ScalarValue::try_from_array(col, idx))
+        .collect::<Result<Vec<ScalarValue>>>()?;
+    let end_range = if let Some(delta) = delta {
+        // The shift direction follows the ordering of the leading sort key.
+        let is_descending: bool = sort_options
+            .first()
+            .ok_or_else(|| DataFusionError::Internal("Empty sort options".to_string()))?
+            .descending;
+        current_row_values
+            .iter()
+            .map(|value| {
+                if value.is_null() {
+                    Ok(value.clone())
+                } else if SEARCH_SIDE == is_descending {
+                    // TODO: Handle positive overflows
+                    value.add(delta)
+                } else if value.is_unsigned() && value < delta {
+                    // Clamp at zero rather than underflow; `value - value` gives a
+                    // zero of the right type without per-type coercion code.
+                    value.sub(value)
+                } else {
+                    // TODO: Handle negative overflows
+                    value.sub(delta)
+                }
+            })
+            .collect::<Result<Vec<ScalarValue>>>()?
+    } else {
+        current_row_values
+    };
+    bisect::<BISECT_SIDE>(range_columns, &end_range, sort_options)
 }