You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ak...@apache.org on 2023/06/21 06:56:12 UTC

[arrow-datafusion] branch main updated: Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)

This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c304e0519 Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)
3c304e0519 is described below

commit 3c304e05194c4e14e020e27a0d00578ed3a749fb
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Jun 21 02:56:05 2023 -0400

    Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate (#6690)
    
    * Move `PartitionEvaluator` and window_state structures to `datafusion_expr` crate
    
    * Update docs
---
 .../windows/bounded_window_agg_exec.rs             |   4 +-
 datafusion/expr/src/lib.rs                         |   3 +
 .../src/window => expr/src}/partition_evaluator.rs |  11 +-
 .../src/window_state.rs}                           | 131 +++++++++++++++++----
 datafusion/physical-expr/src/window/built_in.rs    |   7 +-
 .../src/window/built_in_window_function_expr.rs    |   2 +-
 datafusion/physical-expr/src/window/cume_dist.rs   |   2 +-
 datafusion/physical-expr/src/window/lead_lag.rs    |   5 +-
 datafusion/physical-expr/src/window/mod.rs         |   4 -
 datafusion/physical-expr/src/window/nth_value.rs   |   5 +-
 datafusion/physical-expr/src/window/ntile.rs       |   2 +-
 datafusion/physical-expr/src/window/rank.rs        |   5 +-
 datafusion/physical-expr/src/window/row_number.rs  |   2 +-
 datafusion/physical-expr/src/window/window_expr.rs | 102 +---------------
 14 files changed, 142 insertions(+), 143 deletions(-)

diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
index 2512776e8d..9c86abec81 100644
--- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs
@@ -41,6 +41,7 @@ use arrow::{
     datatypes::{Schema, SchemaBuilder, SchemaRef},
     record_batch::RecordBatch,
 };
+use datafusion_expr::window_state::{PartitionBatchState, WindowAggState};
 use futures::stream::Stream;
 use futures::{ready, StreamExt};
 use hashbrown::raw::RawTable;
@@ -62,8 +63,7 @@ use datafusion_common::DataFusionError;
 use datafusion_expr::ColumnarValue;
 use datafusion_physical_expr::hash_utils::create_hashes;
 use datafusion_physical_expr::window::{
-    PartitionBatchState, PartitionBatches, PartitionKey, PartitionWindowAggStates,
-    WindowAggState, WindowState,
+    PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState,
 };
 use datafusion_physical_expr::{
     EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr,
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 1675afb9c9..ccb9728877 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -41,6 +41,7 @@ mod literal;
 pub mod logical_plan;
 mod nullif;
 mod operator;
+mod partition_evaluator;
 mod signature;
 pub mod struct_expressions;
 mod table_source;
@@ -51,6 +52,7 @@ mod udf;
 pub mod utils;
 pub mod window_frame;
 pub mod window_function;
+pub mod window_state;
 
 pub use accumulator::Accumulator;
 pub use aggregate_function::AggregateFunction;
@@ -69,6 +71,7 @@ pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
 pub use logical_plan::*;
 pub use nullif::SUPPORTED_NULLIF_TYPES;
 pub use operator::Operator;
+pub use partition_evaluator::PartitionEvaluator;
 pub use signature::{Signature, TypeSignature, Volatility};
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
 pub use udaf::AggregateUDF;
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs
similarity index 96%
rename from datafusion/physical-expr/src/window/partition_evaluator.rs
rename to datafusion/expr/src/partition_evaluator.rs
index e518e89a75..6b159d7105 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/expr/src/partition_evaluator.rs
@@ -17,20 +17,21 @@
 
 //! Partition evaluation module
 
-use crate::window::WindowAggState;
 use arrow::array::ArrayRef;
 use datafusion_common::Result;
 use datafusion_common::{DataFusionError, ScalarValue};
 use std::fmt::Debug;
 use std::ops::Range;
 
+use crate::window_state::WindowAggState;
+
 /// Partition evaluator for Window Functions
 ///
 /// # Background
 ///
 /// An implementation of this trait is created and used for each
 /// partition defined by an `OVER` clause and is instantiated by
-/// [`BuiltInWindowFunctionExpr::create_evaluator`]
+/// the DataFusion runtime.
 ///
 /// For example, evaluating `window_func(val) OVER (PARTITION BY col)`
 /// on the following data:
@@ -65,7 +66,8 @@ use std::ops::Range;
 /// ```
 ///
 /// Different methods on this trait will be called depending on the
-/// capabilities described by [`BuiltInWindowFunctionExpr`]:
+/// capabilities described by [`Self::supports_bounded_execution`],
+/// [`Self::uses_window_frame`], and [`Self::include_rank`]:
 ///
 /// # Stateless `PartitionEvaluator`
 ///
@@ -95,9 +97,6 @@ use std::ops::Range;
 /// |false|true|`evaluate` (optionally can also implement `evaluate_all` for more optimized implementation. However, there will be default implementation that is suboptimal) . If we were to implement `ROW_NUMBER` it will end up in this quadrant. Example `OddRowNumber` showcases this use case|
 /// |true|false|`evaluate` (I think as long as `uses_window_frame` is `true`. There is no way for `supports_bounded_execution` to be false). I couldn't come up with any example for this quadrant |
 /// |true|true|`evaluate`. If we were to implement `FIRST_VALUE`, it would end up in this quadrant|.
-///
-/// [`BuiltInWindowFunctionExpr`]: crate::window::BuiltInWindowFunctionExpr
-/// [`BuiltInWindowFunctionExpr::create_evaluator`]: crate::window::BuiltInWindowFunctionExpr::create_evaluator
 pub trait PartitionEvaluator: Debug + Send {
     /// Updates the internal state for window function
     ///
diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/expr/src/window_state.rs
similarity index 85%
rename from datafusion/physical-expr/src/window/window_frame_state.rs
rename to datafusion/expr/src/window_state.rs
index e23a58a09b..09ed83a5a3 100644
--- a/datafusion/physical-expr/src/window/window_frame_state.rs
+++ b/datafusion/expr/src/window_state.rs
@@ -15,19 +15,100 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! This module provides utilities for window frame index calculations
-//! depending on the window frame mode: RANGE, ROWS, GROUPS.
-
-use arrow::array::ArrayRef;
-use arrow::compute::kernels::sort::SortOptions;
-use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
-use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
-use std::cmp::min;
-use std::collections::VecDeque;
-use std::fmt::Debug;
-use std::ops::Range;
-use std::sync::Arc;
+//! Structures used to hold window function state (for implementing WindowUDFs)
+
+use std::{collections::VecDeque, ops::Range, sync::Arc};
+
+use arrow::{
+    array::ArrayRef,
+    compute::{concat, SortOptions},
+    datatypes::DataType,
+    record_batch::RecordBatch,
+};
+use datafusion_common::{
+    utils::{compare_rows, get_row_at_idx, search_in_slice},
+    DataFusionError, Result, ScalarValue,
+};
+
+use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+
+/// Holds the state of evaluating a window function
+#[derive(Debug)]
+pub struct WindowAggState {
+    /// The range that we calculate the window function
+    pub window_frame_range: Range<usize>,
+    pub window_frame_ctx: Option<WindowFrameContext>,
+    /// The index of the last row whose result has been calculated, within the partition record batch buffer.
+    pub last_calculated_index: usize,
+    /// The offset of the deleted row number
+    pub offset_pruned_rows: usize,
+    /// Stores the results calculated by window frame
+    pub out_col: ArrayRef,
+    /// Keeps track of how many rows should be generated to be in sync with input record_batch.
+    // (For each row in the input record batch we need to generate a window result).
+    pub n_row_result_missing: usize,
+    /// flag indicating whether we have received all data for this partition
+    pub is_end: bool,
+}
+
+impl WindowAggState {
+    pub fn prune_state(&mut self, n_prune: usize) {
+        self.window_frame_range = Range {
+            start: self.window_frame_range.start - n_prune,
+            end: self.window_frame_range.end - n_prune,
+        };
+        self.last_calculated_index -= n_prune;
+        self.offset_pruned_rows += n_prune;
+
+        match self.window_frame_ctx.as_mut() {
+            // Rows have no state do nothing
+            Some(WindowFrameContext::Rows(_)) => {}
+            Some(WindowFrameContext::Range { .. }) => {}
+            Some(WindowFrameContext::Groups { state, .. }) => {
+                let mut n_group_to_del = 0;
+                for (_, end_idx) in &state.group_end_indices {
+                    if n_prune < *end_idx {
+                        break;
+                    }
+                    n_group_to_del += 1;
+                }
+                state.group_end_indices.drain(0..n_group_to_del);
+                state
+                    .group_end_indices
+                    .iter_mut()
+                    .for_each(|(_, start_idx)| *start_idx -= n_prune);
+                state.current_group_idx -= n_group_to_del;
+            }
+            None => {}
+        };
+    }
+
+    pub fn update(
+        &mut self,
+        out_col: &ArrayRef,
+        partition_batch_state: &PartitionBatchState,
+    ) -> Result<()> {
+        self.last_calculated_index += out_col.len();
+        self.out_col = concat(&[&self.out_col, &out_col])?;
+        self.n_row_result_missing =
+            partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
+        self.is_end = partition_batch_state.is_end;
+        Ok(())
+    }
+
+    pub fn new(out_type: &DataType) -> Result<Self> {
+        let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
+        Ok(Self {
+            window_frame_range: Range { start: 0, end: 0 },
+            window_frame_ctx: None,
+            last_calculated_index: 0,
+            offset_pruned_rows: 0,
+            out_col: empty_out_col,
+            n_row_result_missing: 0,
+            is_end: false,
+        })
+    }
+}
 
 /// This object stores the window frame state for use in incremental calculations.
 #[derive(Debug)]
@@ -125,7 +206,7 @@ impl WindowFrameContext {
                 )))
             }
             WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
-                min(idx + n as usize, length)
+                std::cmp::min(idx + n as usize, length)
             }
             // ERRONEOUS FRAMES
             WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
@@ -150,7 +231,7 @@ impl WindowFrameContext {
             // UNBOUNDED FOLLOWING
             WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
             WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
-                min(idx + n as usize + 1, length)
+                std::cmp::min(idx + n as usize + 1, length)
             }
             // ERRONEOUS FRAMES
             WindowFrameBound::Preceding(_) | WindowFrameBound::Following(_) => {
@@ -161,6 +242,17 @@ impl WindowFrameContext {
     }
 }
 
+/// State for each unique partition determined according to PARTITION BY column(s)
+#[derive(Debug)]
+pub struct PartitionBatchState {
+    /// The record_batch belonging to current partition
+    pub record_batch: RecordBatch,
+    /// Flag indicating whether we have received all data for this partition
+    pub is_end: bool,
+    /// Number of rows emitted for each partition
+    pub n_out_row: usize,
+}
+
 /// This structure encapsulates all the state information we require as we scan
 /// ranges of data while processing RANGE frames.
 /// Attribute `sort_options` stores the column ordering specified by the ORDER
@@ -510,7 +602,7 @@ impl WindowFrameStateGroups {
         Ok(match (SIDE, SEARCH_SIDE) {
             // Window frame start:
             (true, _) => {
-                let group_idx = min(group_idx, self.group_end_indices.len());
+                let group_idx = std::cmp::min(group_idx, self.group_end_indices.len());
                 if group_idx > 0 {
                     // Normally, start at the boundary of the previous group.
                     self.group_end_indices[group_idx - 1].1
@@ -531,7 +623,7 @@ impl WindowFrameStateGroups {
             }
             // Window frame end, FOLLOWING n
             (false, false) => {
-                let group_idx = min(
+                let group_idx = std::cmp::min(
                     self.current_group_idx + delta,
                     self.group_end_indices.len() - 1,
                 );
@@ -547,11 +639,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<boo
 
 #[cfg(test)]
 mod tests {
-    use crate::window::window_frame_state::WindowFrameStateGroups;
+    use super::*;
+    use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
     use arrow::array::{ArrayRef, Float64Array};
-    use arrow_schema::SortOptions;
     use datafusion_common::{Result, ScalarValue};
-    use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
     use std::ops::Range;
     use std::sync::Arc;
 
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index 828bc7218f..a528676c26 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -21,13 +21,10 @@ use std::any::Any;
 use std::ops::Range;
 use std::sync::Arc;
 
-use super::window_frame_state::WindowFrameContext;
 use super::BuiltInWindowFunctionExpr;
 use super::WindowExpr;
 use crate::window::window_expr::WindowFn;
-use crate::window::{
-    PartitionBatches, PartitionWindowAggStates, WindowAggState, WindowState,
-};
+use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState};
 use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr};
 use arrow::array::{new_empty_array, ArrayRef};
 use arrow::compute::SortOptions;
@@ -35,6 +32,8 @@ use arrow::datatypes::Field;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::utils::evaluate_partition_ranges;
 use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::window_state::WindowFrameContext;
 use datafusion_expr::WindowFrame;
 
 /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`].
diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
index 432bf78368..73e0658267 100644
--- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
+++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
@@ -15,13 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use super::partition_evaluator::PartitionEvaluator;
 use crate::equivalence::OrderingEquivalenceBuilder;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::datatypes::Field;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::sync::Arc;
 
diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs
index 9040165ac9..49ed2a74df 100644
--- a/datafusion/physical-expr/src/window/cume_dist.rs
+++ b/datafusion/physical-expr/src/window/cume_dist.rs
@@ -18,13 +18,13 @@
 //! Defines physical expression for `cume_dist` that can evaluated
 //! at runtime during query execution
 
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::array::Float64Array;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::iter;
 use std::ops::Range;
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index 24248f989e..637297b4cf 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -18,15 +18,16 @@
 //! Defines physical expression for `lead` and `lag` that can evaluated
 //! at runtime during query execution
 
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::window_expr::LeadLagState;
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::compute::cast;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::cmp::min;
 use std::ops::{Neg, Range};
diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs
index 4c8b8b5a4e..b1234d599a 100644
--- a/datafusion/physical-expr/src/window/mod.rs
+++ b/datafusion/physical-expr/src/window/mod.rs
@@ -22,21 +22,17 @@ pub(crate) mod cume_dist;
 pub(crate) mod lead_lag;
 pub(crate) mod nth_value;
 pub(crate) mod ntile;
-pub(crate) mod partition_evaluator;
 pub(crate) mod rank;
 pub(crate) mod row_number;
 mod sliding_aggregate;
 mod window_expr;
-mod window_frame_state;
 
 pub use aggregate::PlainAggregateWindowExpr;
 pub use built_in::BuiltInWindowExpr;
 pub use built_in_window_function_expr::BuiltInWindowFunctionExpr;
 pub use sliding_aggregate::SlidingAggregateWindowExpr;
-pub use window_expr::PartitionBatchState;
 pub use window_expr::PartitionBatches;
 pub use window_expr::PartitionKey;
 pub use window_expr::PartitionWindowAggStates;
-pub use window_expr::WindowAggState;
 pub use window_expr::WindowExpr;
 pub use window_expr::WindowState;
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index e6dbeba834..2d592bbb6f 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -18,14 +18,15 @@
 //! Defines physical expressions for `first_value`, `last_value`, and `nth_value`
 //! that can evaluated at runtime during query execution
 
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::window_expr::{NthValueKind, NthValueState};
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::{Array, ArrayRef};
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::ops::Range;
 use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs
index 2feab9956a..6019ffbeef 100644
--- a/datafusion/physical-expr/src/window/ntile.rs
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -18,13 +18,13 @@
 //! Defines physical expression for `ntile` that can evaluated
 //! at runtime during query execution
 
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::{ArrayRef, UInt64Array};
 use arrow::datatypes::Field;
 use arrow_schema::DataType;
 use datafusion_common::Result;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::sync::Arc;
 
diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs
index 59a08358cd..527eaab611 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -18,15 +18,16 @@
 //! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated
 //! at runtime during query execution
 
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::window_expr::RankState;
-use crate::window::{BuiltInWindowFunctionExpr, WindowAggState};
+use crate::window::BuiltInWindowFunctionExpr;
 use crate::PhysicalExpr;
 use arrow::array::ArrayRef;
 use arrow::array::{Float64Array, UInt64Array};
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::utils::get_row_at_idx;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::window_state::WindowAggState;
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::iter;
 use std::ops::Range;
diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs
index 3c1f0583b4..b115e9f149 100644
--- a/datafusion/physical-expr/src/window/row_number.rs
+++ b/datafusion/physical-expr/src/window/row_number.rs
@@ -19,7 +19,6 @@
 
 use crate::equivalence::OrderingEquivalenceBuilder;
 use crate::expressions::Column;
-use crate::window::partition_evaluator::PartitionEvaluator;
 use crate::window::window_expr::NumRowsState;
 use crate::window::BuiltInWindowFunctionExpr;
 use crate::{PhysicalExpr, PhysicalSortExpr};
@@ -27,6 +26,7 @@ use arrow::array::{ArrayRef, UInt64Array};
 use arrow::datatypes::{DataType, Field};
 use arrow_schema::SortOptions;
 use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::PartitionEvaluator;
 use std::any::Any;
 use std::ops::Range;
 use std::sync::Arc;
diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs
index dbb21b1f3e..9175d97525 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -15,16 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::window::partition_evaluator::PartitionEvaluator;
-use crate::window::window_frame_state::WindowFrameContext;
 use crate::{PhysicalExpr, PhysicalSortExpr};
 use arrow::array::{new_empty_array, Array, ArrayRef};
 use arrow::compute::kernels::sort::SortColumn;
-use arrow::compute::{concat, SortOptions};
+use arrow::compute::SortOptions;
 use arrow::datatypes::Field;
 use arrow::record_batch::RecordBatch;
-use arrow_schema::DataType;
 use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::window_state::{
+    PartitionBatchState, WindowAggState, WindowFrameContext,
+};
+use datafusion_expr::PartitionEvaluator;
 use datafusion_expr::{Accumulator, WindowFrame};
 use indexmap::IndexMap;
 use std::any::Any;
@@ -327,84 +328,6 @@ pub struct LeadLagState {
     pub idx: usize,
 }
 
-/// Holds the state of evaluating a window function
-#[derive(Debug)]
-pub struct WindowAggState {
-    /// The range that we calculate the window function
-    pub window_frame_range: Range<usize>,
-    pub window_frame_ctx: Option<WindowFrameContext>,
-    /// The index of the last row that its result is calculated inside the partition record batch buffer.
-    pub last_calculated_index: usize,
-    /// The offset of the deleted row number
-    pub offset_pruned_rows: usize,
-    /// Stores the results calculated by window frame
-    pub out_col: ArrayRef,
-    /// Keeps track of how many rows should be generated to be in sync with input record_batch.
-    // (For each row in the input record batch we need to generate a window result).
-    pub n_row_result_missing: usize,
-    /// flag indicating whether we have received all data for this partition
-    pub is_end: bool,
-}
-
-impl WindowAggState {
-    pub fn prune_state(&mut self, n_prune: usize) {
-        self.window_frame_range = Range {
-            start: self.window_frame_range.start - n_prune,
-            end: self.window_frame_range.end - n_prune,
-        };
-        self.last_calculated_index -= n_prune;
-        self.offset_pruned_rows += n_prune;
-
-        match self.window_frame_ctx.as_mut() {
-            // Rows have no state do nothing
-            Some(WindowFrameContext::Rows(_)) => {}
-            Some(WindowFrameContext::Range { .. }) => {}
-            Some(WindowFrameContext::Groups { state, .. }) => {
-                let mut n_group_to_del = 0;
-                for (_, end_idx) in &state.group_end_indices {
-                    if n_prune < *end_idx {
-                        break;
-                    }
-                    n_group_to_del += 1;
-                }
-                state.group_end_indices.drain(0..n_group_to_del);
-                state
-                    .group_end_indices
-                    .iter_mut()
-                    .for_each(|(_, start_idx)| *start_idx -= n_prune);
-                state.current_group_idx -= n_group_to_del;
-            }
-            None => {}
-        };
-    }
-}
-
-impl WindowAggState {
-    pub fn update(
-        &mut self,
-        out_col: &ArrayRef,
-        partition_batch_state: &PartitionBatchState,
-    ) -> Result<()> {
-        self.last_calculated_index += out_col.len();
-        self.out_col = concat(&[&self.out_col, &out_col])?;
-        self.n_row_result_missing =
-            partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
-        self.is_end = partition_batch_state.is_end;
-        Ok(())
-    }
-}
-
-/// State for each unique partition determined according to PARTITION BY column(s)
-#[derive(Debug)]
-pub struct PartitionBatchState {
-    /// The record_batch belonging to current partition
-    pub record_batch: RecordBatch,
-    /// Flag indicating whether we have received all data for this partition
-    pub is_end: bool,
-    /// Number of rows emitted for each partition
-    pub n_out_row: usize,
-}
-
 /// Key for IndexMap for each unique partition
 ///
 /// For instance, if window frame is `OVER(PARTITION BY a,b)`,
@@ -420,18 +343,3 @@ pub type PartitionWindowAggStates = IndexMap<PartitionKey, WindowState>;
 
 /// The IndexMap (i.e. an ordered HashMap) where record batches are separated for each partition.
 pub type PartitionBatches = IndexMap<PartitionKey, PartitionBatchState>;
-
-impl WindowAggState {
-    pub fn new(out_type: &DataType) -> Result<Self> {
-        let empty_out_col = ScalarValue::try_from(out_type)?.to_array_of_size(0);
-        Ok(Self {
-            window_frame_range: Range { start: 0, end: 0 },
-            window_frame_ctx: None,
-            last_calculated_index: 0,
-            offset_pruned_rows: 0,
-            out_col: empty_out_col,
-            n_row_result_missing: 0,
-            is_end: false,
-        })
-    }
-}