You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ak...@apache.org on 2023/06/21 06:56:12 UTC
[arrow-datafusion] branch main updated: Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)
This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3c304e0519 Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)
3c304e0519 is described below
commit 3c304e05194c4e14e020e27a0d00578ed3a749fb
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Jun 21 02:56:05 2023 -0400
Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)
* Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate
* Update docs
---
.../windows/bounded_window_agg_exec.rs | 4 +-
datafusion/expr/src/lib.rs | 3 +
.../src/window => expr/src}/partition_evaluator.rs | 11 +-
.../src/window_state.rs} | 131 +++++++++++++++++----
datafusion/physical-expr/src/window/built_in.rs | 7 +-
.../src/window/built_in_window_function_expr.rs | 2 +-
datafusion/physical-expr/src/window/cume_dist.rs | 2 +-
datafusion/physical-expr/src/window/lead_lag.rs | 5 +-
datafusion/physical-expr/src/window/mod.rs | 4 -
datafusion/physical-expr/src/window/nth_value.rs | 5 +-
datafusion/physical-expr/src/window/ntile.rs | 2 +-
datafusion/physical-expr/src/window/rank.rs | 5 +-
datafusion/physical-expr/src/window/row_number.rs | 2 +-
datafusion/physical-expr/src/window/window_expr.rs | 102 +---------------
14 files changed, 142 insertions(+), 143 deletions(-)
diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
index 2512776e8d..9c86abec81 100644
--- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
@@ -41,6 +41,7 @@ use arrow::{
datatypes::{Schema, SchemaBuilder, SchemaRef},
record_batch::RecordBatch,
};
+use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
use futures::stream::Stream;
use futures::{ready, StreamExt};
use hashbrown::raw::RawTable;
@@ -62,8 +63,7 @@ use datafusion_common::DataFusionError;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::hash_utils::create_hashes;
use datafusion_physical_expr::window::{
- PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates,
- WindowAggState, WindowState,
+ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
};
use datafusion_physical_expr::{
EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr,
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 1675afb9c9..ccb9728877 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -41,6 +41,7 @@ mod literal;
pub mod logical_plan;
mod nullif;
mod operator;
+mod partition_evaluator;
mod signature;
pub mod struct_expressions;
mod table_source;
@@ -51,6 +52,7 @@ mod udf;
pub mod utils;
pub mod window_frame;
pub mod window_function;
+pub mod window_state;
pub use accumulator::Accumulator;
pub use aggregate_function::AggregateFunction;
@@ -69,6 +71,7 @@ pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
pub use logical_plan::*;
pub use nullif::SUPPORTED_NULLIF_TYPES;
pub use operator::Operator;
+pub use partition_evaluator::PartitionEvaluator;
pub use signature::{Signature, TypeSignature, Volatility};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::AggregateUDF;
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs
similarity index 96%
rename from datafusion/physical-expr/src/window/partition_evaluator.rs
rename to datafusion/expr/src/partition_evaluator.rs
index e518e89a75..6b159d7105 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/expr/src/partition_evaluator.rs
@@ -17,20 +17,21 @@
//! Partition evaluation module
-use crate::window::WindowAggState;
use arrow::array::ArrayRef;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, ScalarValue};
use std::fmt::Debug;
use std::ops::Range;
+use crate::window_state::WindowAggState;
+
/// Partition evaluator for Window Functions
///
/// # Background
///
/// An implementation of this trait is created and used for each
/// partition defined by an `OVER` clause and is instantiated by
-/// [`BuiltInWindowFunctionExpr::create_evaluator`]
+/// the DataFusion runtime.
///
/// For example, evaluating `window_func(val) OVER (PARTITION BY col)`
/// on the following data:
@@ -65,7 +66,8 @@ use std::ops::Range;
/// ```
///
/// Different methods on this trait will be called depending on the
-/// capabilities described by [`BuiltInWindowFunctionExpr`]:
+/// capabilities described by [`Self::supports_bounded_execution`],
+/// [`Self::uses_window_frame`], and [`Self::include_rank`]:
///
/// # Stateless `PartitionEvaluator`
///
@@ -95,9 +97,6 @@ use std::ops::Range;
/// |false|true|`evaluate` (optionally can also implement `evaluate_all` for more optimized implementation. However, there will be default implementation that is suboptimal) . If we were to implement `ROW_NUMBER` it will end up in this quadrant. Example `OddRowNumber` showcases this use case|
/// |true|false|`evaluate` (I think as long as `uses_window_frame` is `true`. There is no way for `supports_bounded_execution` to be false). I couldn't come up with any example for this quadrant |
/// |true|true|`evaluate`. If we were to implement `FIRST_VALUE`, it would end up in this quadrant|.
-///
-/// [`BuiltInWindowFunctionExpr`]: crate::window::BuiltInWindowFunctionExpr
-/// [`BuiltInWindowFunctionExpr::create_evaluator`]: crate::window::BuiltInWindowFunctionExpr::create_evaluator
pub trait PartitionEvaluator: Debug + Send {
/// Updates the internal state for window function
///
diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_state.rs
similarity index 85%
rename from datafusion/physical-expr/src/window/window_frame_state.rs
rename to datafusion/expr/src/window_state.rs
index e23a58a09b..09ed83a5a3 100644
--- a/datafusion/physical-expr/src/window/window_frame_state.rs
+++ b/datafusion/expr/src/window_state.rs
@@ -15,19 +15,100 @@
// specific language governing permissions and limitations
// under the License.
-//! This module provides utilities for window frame index calculations
-//! depending on the window frame mode: RANGE, ROWS, GROUPS.
-
-use arrow::array::ArrayRef;
-use arrow::compute::kernels::sort::SortOptions;
-use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
-use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
-use std::cmp::min;
-use std::collections::VecDeque;
-use std::fmt::Debug;
-use std::ops::Range;
-use std::sync::Arc;
+//! Structures used to hold window function state (for implementing WindowUDFs)
+
+use std::{collections::VecDeque, ops::Range, sync::Arc};
+
+use arrow::{
+ array::ArrayRef,
+ compute::{concat, SortOptions},
+ datatypes::DataType,
+ record_batch::RecordBatch,
+};
+use datafusion_common::{
+ utils::{compare_rows, get_row_at_idx, search_in_slice},
+ DataFusionError, Result, ScalarValue,
+};
+
+use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+
+/// Holds the state of evaluating a window function
+#[derive(Debug)]
+pub struct WindowAggState {
+ /// The range that we calculate the window function
+ pub window_frame_range: Range<usize>,
+ pub window_frame_ctx: Option<WindowFrameContext>,
+ /// The index of the last row that its result is calculated inside the partition record batch buffer.
+ pub last_calculated_index: usize,
+ /// The offset of the deleted row number
+ pub offset_pruned_rows: usize,
+ /// Stores the results calculated by window frame
+ pub out_col: ArrayRef,
+ /// Keeps track of how many rows should be generated to be in sync with input record_batch.
+ // (For each row in the input record batch we need to generate a window result).
+ pub n_row_result_missing: usize,
+ /// flag indicating whether we have received all data for this partition
+ pub is_end: bool,
+}
+
+impl WindowAggState {
+ pub fn prune_state(&mut self, n_prune: usize) {
+ self.window_frame_range = Range {
+ start: self.window_frame_range.start - n_prune,
+ end: self.window_frame_range.end - n_prune,
+ };
+ self.last_calculated_index -= n_prune;
+ self.offset_pruned_rows += n_prune;
+
+ match self.window_frame_ctx.as_mut() {
+ // Rows have no state do nothing
+ Some(WindowFrameContext::Rows(_)) => {}
+ Some(WindowFrameContext::Range { .. }) => {}
+ Some(WindowFrameContext::Groups { state, .. }) => {
+ let mut n_group_to_del = 0;
+ for (_, end_idx) in &state.group_end_indices {
+ if n_prune < *end_idx {
+ break;
+ }
+ n_group_to_del += 1;
+ }
+ state.group_end_indices.drain(0..n_group_to_del);
+ state
+ .group_end_indices
+ .iter_mut()
+ .for_each(|(_, start_idx)| *start_idx -= n_prune);
+ state.current_group_idx -= n_group_to_del;
+ }
+ None => {}
+ };
+ }
+
+ pub fn update(
+ &mut self,
+ out_col: &ArrayRef,
+ partition_batch_state: &PartitionBatchState,
+ ) -> Result<()> {
+ self.last_calculated_index += out_col.len();
+ self.out_col = concat(&[&self.out_col, &out_col])?;
+ self.n_row_result_missing =
+ partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
+ self.is_end = partition_batch_state.is_end;
+ Ok(())
+ }
+
+ pub fn new(out_type: &DataType) -> Result<Self> {
+ let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
+ Ok(Self {
+ window_frame_range: Range { start: 0, end: 0 },
+ window_frame_ctx: None,
+ last_calculated_index: 0,
+ offset_pruned_rows: 0,
+ out_col: empty_out_col,
+ n_row_result_missing: 0,
+ is_end: false,
+ })
+ }
+}
/// This object stores the window frame state for use in incremental calculations.
#[derive(Debug)]
@@ -125,7 +206,7 @@ impl WindowFrameContext {
)))
}
WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
- min(idx + n as usize, length)
+ std::cmp::min(idx + n as usize, length)
}
// ERRONEOUS FRAMES
WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
@@ -150,7 +231,7 @@ impl WindowFrameContext {
// UNBOUNDED FOLLOWING
WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
- min(idx + n as usize + 1, length)
+ std::cmp::min(idx + n as usize + 1, length)
}
// ERRONEOUS FRAMES
WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
@@ -161,6 +242,17 @@ impl WindowFrameContext {
}
}
+/// State for each unique partition determined according to PARTITION BY column(s)
+#[derive(Debug)]
+pub struct PartitionBatchState {
+ /// The record_batch belonging to current partition
+ pub record_batch: RecordBatch,
+ /// Flag indicating whether we have received all data for this partition
+ pub is_end: bool,
+ /// Number of rows emitted for each partition
+ pub n_out_row: usize,
+}
+
/// This structure encapsulates all the state information we require as we scan
/// ranges of data while processing RANGE frames.
/// Attribute `sort_options` stores the column ordering specified by the ORDER
@@ -510,7 +602,7 @@ impl WindowFrameStateGroups {
Ok(match (SIDE, SEARCH_SIDE) {
// Window frame start:
(true, _) => {
- let group_idx = min(group_idx, self.group_end_indices.len());
+ let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
if group_idx > 0 {
// Normally, start at the boundary of the previous group.
self.group_end_indices[group_idx - 1].1
@@ -531,7 +623,7 @@ impl WindowFrameStateGroups {
}
// Window frame end, FOLLOWING n
(false, false) => {
- let group_idx = min(
+ let group_idx = std::cmp::min(
self.current_group_idx + delta,
self.group_end_indices.len() - 1,
);
@@ -547,11 +639,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<boo
#[cfg(test)]
mod tests {
- use crate::window::window_frame_state::WindowFrameStateGroups;
+ use super::*;
+ use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use arrow::array::{ArrayRef, Float64Array};
- use arrow_schema::SortOptions;
use datafusion_common::{Result, ScalarValue};
- use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::ops::Range;
use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index 828bc7218f..a528676c26 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -21,13 +21,10 @@ use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
-use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
use crate::window::window_expr::WindowFn;
-use crate::window::{
- PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
-};
+use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState};
use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
use arrow::array::{new_empty_array, ArrayRef};
use arrow::compute::SortOptions;
@@ -35,6 +32,8 @@ use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
use datafusion_common::utils::evaluate_partition_ranges;
use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::window_state::WindowFrameContext;
use datafusion_expr::WindowFrame;
/// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`].
diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
index 432bf78368..73e0658267 100644
--- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
+++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
@@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-use super::partition_evaluator::PartitionEvaluator;
use crate::equivalence::OrderingEquivalenceBuilder;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs
index 9040165ac9..49ed2a74df 100644
--- a/datafusion/physical-expr/src/window/cume_dist.rs
+++ b/datafusion/physical-expr/src/window/cume_dist.rs
@@ -18,13 +18,13 @@
//! Defines physical expression for `cume_dist` that can evaluated
//! at runtime during query execution
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::array::Float64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::iter;
use std::ops::Range;
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index 24248f989e..637297b4cf 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -18,15 +18,16 @@
//! Defines physical expression for `lead` and `lag` that can evaluated
//! at runtime during query execution
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::LeadLagState;
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::cmp::min;
use std::ops::{Neg, Range};
diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs
index 4c8b8b5a4e..b1234d599a 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -22,21 +22,17 @@ pub(crate) mod cume_dist;
pub(crate) mod lead_lag;
pub(crate) mod nth_value;
pub(crate) mod ntile;
-pub(crate) mod partition_evaluator;
pub(crate) mod rank;
pub(crate) mod row_number;
mod sliding_aggregate;
mod window_expr;
-mod window_frame_state;
pub use aggregate::PlainAggregateWindowExpr;
pub use built_in::BuiltInWindowExpr;
pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
pub use sliding_aggregate::SlidingAggregateWindowExpr;
-pub use window_expr::PartitionBatchState;
pub use window_expr::PartitionBatches;
pub use window_expr::PartitionKey;
pub use window_expr::PartitionWindowAggStates;
-pub use window_expr::WindowAggState;
pub use window_expr::WindowExpr;
pub use window_expr::WindowState;
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index e6dbeba834..2d592bbb6f 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -18,14 +18,15 @@
//! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
//! that can evaluated at runtime during query execution
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::{NthValueKind, NthValueState};
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs
index 2feab9956a..6019ffbeef 100644
--- a/datafusion/physical-expr/src/window/ntile.rs
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -18,13 +18,13 @@
//! Defines physical expression for `ntile` that can evaluated
//! at runtime during query execution
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::Field;
use arrow_schema::DataType;
use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs
index 59a08358cd..527eaab611 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -18,15 +18,16 @@
//! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated
//! at runtime during query execution
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::RankState;
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::array::{Float64Array, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_common::utils::get_row_at_idx;
use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::iter;
use std::ops::Range;
diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs
index 3c1f0583b4..b115e9f149 100644
--- a/datafusion/physical-expr/src/window/row_number.rs
+++ b/datafusion/physical-expr/src/window/row_number.rs
@@ -19,7 +19,6 @@
use crate::equivalence::OrderingEquivalenceBuilder;
use crate::expressions::Column;
-use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::window_expr::NumRowsState;
use crate::window::BuiltInWindowFunctionExpr;
use crate::{PhysicalExpr, PhysicalSortExpr};
@@ -27,6 +26,7 @@ use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use arrow_schema::SortOptions;
use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::ops::Range;
use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs
index dbb21b1f3e..9175d97525 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -15,16 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::window_frame_state::WindowFrameContext;
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::{new_empty_array, Array, ArrayRef};
use arrow::compute::kernels::sort::SortColumn;
-use arrow::compute::{concat, SortOptions};
+use arrow::compute::SortOptions;
use arrow::datatypes::Field;
use arrow::record_batch::RecordBatch;
-use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::window_state::{
+ PartitionBatchState, WindowAggState, WindowFrameContext,
+};
+use datafusion_expr::PartitionEvaluator;
use datafusion_expr::{Accumulator, WindowFrame};
use indexmap::IndexMap;
use std::any::Any;
@@ -327,84 +328,6 @@ pub struct LeadLagState {
pub idx: usize,
}
-/// Holds the state of evaluating a window function
-#[derive(Debug)]
-pub struct WindowAggState {
- /// The range that we calculate the window function
- pub window_frame_range: Range<usize>,
- pub window_frame_ctx: Option<WindowFrameContext>,
- /// The index of the last row that its result is calculated inside the partition record batch buffer.
- pub last_calculated_index: usize,
- /// The offset of the deleted row number
- pub offset_pruned_rows: usize,
- /// Stores the results calculated by window frame
- pub out_col: ArrayRef,
- /// Keeps track of how many rows should be generated to be in sync with input record_batch.
- // (For each row in the input record batch we need to generate a window result).
- pub n_row_result_missing: usize,
- /// flag indicating whether we have received all data for this partition
- pub is_end: bool,
-}
-
-impl WindowAggState {
- pub fn prune_state(&mut self, n_prune: usize) {
- self.window_frame_range = Range {
- start: self.window_frame_range.start - n_prune,
- end: self.window_frame_range.end - n_prune,
- };
- self.last_calculated_index -= n_prune;
- self.offset_pruned_rows += n_prune;
-
- match self.window_frame_ctx.as_mut() {
- // Rows have no state do nothing
- Some(WindowFrameContext::Rows(_)) => {}
- Some(WindowFrameContext::Range { .. }) => {}
- Some(WindowFrameContext::Groups { state, .. }) => {
- let mut n_group_to_del = 0;
- for (_, end_idx) in &state.group_end_indices {
- if n_prune < *end_idx {
- break;
- }
- n_group_to_del += 1;
- }
- state.group_end_indices.drain(0..n_group_to_del);
- state
- .group_end_indices
- .iter_mut()
- .for_each(|(_, start_idx)| *start_idx -= n_prune);
- state.current_group_idx -= n_group_to_del;
- }
- None => {}
- };
- }
-}
-
-impl WindowAggState {
- pub fn update(
- &mut self,
- out_col: &ArrayRef,
- partition_batch_state: &PartitionBatchState,
- ) -> Result<()> {
- self.last_calculated_index += out_col.len();
- self.out_col = concat(&[&self.out_col, &out_col])?;
- self.n_row_result_missing =
- partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
- self.is_end = partition_batch_state.is_end;
- Ok(())
- }
-}
-
-/// State for each unique partition determined according to PARTITION BY column(s)
-#[derive(Debug)]
-pub struct PartitionBatchState {
- /// The record_batch belonging to current partition
- pub record_batch: RecordBatch,
- /// Flag indicating whether we have received all data for this partition
- pub is_end: bool,
- /// Number of rows emitted for each partition
- pub n_out_row: usize,
-}
-
/// Key for IndexMap for each unique partition
///
/// For instance, if window frame is `OVER(PARTITION BY a,b)`,
@@ -420,18 +343,3 @@ pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
/// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition.
pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
-
-impl WindowAggState {
- pub fn new(out_type: &DataType) -> Result<Self> {
- let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
- Ok(Self {
- window_frame_range: Range { start: 0, end: 0 },
- window_frame_ctx: None,
- last_calculated_index: 0,
- offset_pruned_rows: 0,
- out_col: empty_out_col,
- n_row_result_missing: 0,
- is_end: false,
- })
- }
-}