You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/26 21:17:38 UTC
[arrow-datafusion] branch master updated: Unnecessary SortExec removal rule from Physical Plan (#4691)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8ec511e51 Unnecessary SortExec removal rule from Physical Plan (#4691)
8ec511e51 is described below
commit 8ec511e51cdfdef8e3f79116076fbb962f53f887
Author: Mustafa akur <10...@users.noreply.github.com>
AuthorDate: Tue Dec 27 00:17:32 2022 +0300
Unnecessary SortExec removal rule from Physical Plan (#4691)
* Sort Removal rule initial commit
* move ordering satisfy to the util
* update test and change repartition maintain_input_order impl
* simplifications
* partition by refactor (#28)
* partition by refactor
* minor changes
* Unnecessary tuple to Range conversion is removed
* move transpose under common
* Add naive sort removal rule
* Add todo for finer Sort removal handling
* Refactors to improve readability and reduce nesting
* reverse expr returns Option (no need for support check)
* fix tests
* partition by and order by no longer end up in the same window group
* Refactor to simplify code
* Better comments, change method names
* Resolve errors introduced by syncing
* address reviews
* address reviews
* Rename to less confusing OptimizeSorts
Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
datafusion/common/src/lib.rs | 10 +
datafusion/core/src/execution/context.rs | 7 +
.../core/src/physical_optimizer/enforcement.rs | 76 +-
datafusion/core/src/physical_optimizer/mod.rs | 1 +
.../core/src/physical_optimizer/optimize_sorts.rs | 887 +++++++++++++++++++++
datafusion/core/src/physical_optimizer/utils.rs | 75 ++
datafusion/core/src/physical_plan/common.rs | 25 +
datafusion/core/src/physical_plan/planner.rs | 2 +-
datafusion/core/src/physical_plan/repartition.rs | 11 +-
.../src/physical_plan/windows/window_agg_exec.rs | 81 +-
datafusion/core/tests/sql/explain_analyze.rs | 5 -
datafusion/core/tests/sql/window.rs | 582 +++++++++++++-
datafusion/expr/src/logical_plan/builder.rs | 2 +-
datafusion/expr/src/utils.rs | 62 +-
datafusion/expr/src/window_frame.rs | 29 +
datafusion/physical-expr/src/aggregate/count.rs | 6 +-
datafusion/physical-expr/src/aggregate/mod.rs | 8 +
datafusion/physical-expr/src/aggregate/sum.rs | 6 +-
datafusion/physical-expr/src/window/aggregate.rs | 94 ++-
datafusion/physical-expr/src/window/built_in.rs | 67 +-
.../src/window/built_in_window_function_expr.rs | 6 +
datafusion/physical-expr/src/window/cume_dist.rs | 25 +-
datafusion/physical-expr/src/window/lead_lag.rs | 27 +-
datafusion/physical-expr/src/window/nth_value.rs | 14 +
datafusion/physical-expr/src/window/ntile.rs | 9 +-
.../src/window/partition_evaluator.rs | 55 +-
datafusion/physical-expr/src/window/rank.rs | 28 +-
datafusion/physical-expr/src/window/row_number.rs | 18 +-
.../physical-expr/src/window/sliding_aggregate.rs | 111 +--
datafusion/physical-expr/src/window/window_expr.rs | 51 +-
.../physical-expr/src/window/window_frame_state.rs | 15 +-
31 files changed, 2017 insertions(+), 378 deletions(-)
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 60d693249..392fa3f25 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -30,6 +30,7 @@ pub mod stats;
mod table_reference;
pub mod test_util;
+use arrow::compute::SortOptions;
pub use column::Column;
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
pub use error::{field_not_found, DataFusionError, Result, SchemaError};
@@ -63,3 +64,12 @@ macro_rules! downcast_value {
})?
}};
}
+
+/// Returns the `SortOptions` describing the opposite traversal of `options`:
+/// both the sort direction (`descending`) and the placement of nulls
+/// (`nulls_first`) are flipped.
+// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this.
+pub fn reverse_sort_options(options: SortOptions) -> SortOptions {
+    let SortOptions {
+        descending,
+        nulls_first,
+    } = options;
+    SortOptions {
+        nulls_first: !nulls_first,
+        descending: !descending,
+    }
+}
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 098dafdc0..978bde2a2 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -100,6 +100,7 @@ use url::Url;
use crate::catalog::listing_schema::ListingSchemaProvider;
use crate::datasource::object_store::ObjectStoreUrl;
use crate::execution::memory_pool::MemoryPool;
+use crate::physical_optimizer::optimize_sorts::OptimizeSorts;
use uuid::Uuid;
use super::options::{
@@ -1580,6 +1581,12 @@ impl SessionState {
// To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here.
physical_optimizers.push(Arc::new(BasicEnforcement::new()));
+ // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements.
+ // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary.
+ // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The
+ // rule below performs this analysis and removes unnecessary `SortExec`s.
+ physical_optimizers.push(Arc::new(OptimizeSorts::new()));
+
let mut this = SessionState {
session_id,
optimizer: Optimizer::new(),
diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs
index 06832ac24..4a496a3ef 100644
--- a/datafusion/core/src/physical_optimizer/enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/enforcement.rs
@@ -20,6 +20,7 @@
//!
use crate::config::OPT_TOP_DOWN_JOIN_KEY_REORDERING;
use crate::error::Result;
+use crate::physical_optimizer::utils::{add_sort_above_child, ordering_satisfy};
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -29,8 +30,7 @@ use crate::physical_plan::joins::{
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
use crate::physical_plan::rewrite::TreeNodeRewritable;
-use crate::physical_plan::sorts::sort::SortExec;
-use crate::physical_plan::sorts::sort::SortOptions;
+use crate::physical_plan::sorts::sort::{SortExec, SortOptions};
use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use crate::physical_plan::windows::WindowAggExec;
use crate::physical_plan::Partitioning;
@@ -42,9 +42,8 @@ use datafusion_physical_expr::equivalence::EquivalenceProperties;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::expressions::NoOp;
use datafusion_physical_expr::{
- expr_list_eq_strict_order, normalize_expr_with_equivalence_properties,
- normalize_sort_expr_with_equivalence_properties, AggregateExpr, PhysicalExpr,
- PhysicalSortExpr,
+ expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, AggregateExpr,
+ PhysicalExpr,
};
use std::collections::HashMap;
use std::sync::Arc;
@@ -919,9 +918,7 @@ fn ensure_distribution_and_ordering(
Ok(child)
} else {
let sort_expr = required.unwrap().to_vec();
- Ok(Arc::new(SortExec::new_with_partitioning(
- sort_expr, child, true, None,
- )) as Arc<dyn ExecutionPlan>)
+ add_sort_above_child(&child, sort_expr)
}
})
.collect();
@@ -929,61 +926,6 @@ fn ensure_distribution_and_ordering(
with_new_children_if_necessary(plan, new_children?)
}
-/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs.
-fn ordering_satisfy<F: FnOnce() -> EquivalenceProperties>(
- provided: Option<&[PhysicalSortExpr]>,
- required: Option<&[PhysicalSortExpr]>,
- equal_properties: F,
-) -> bool {
- match (provided, required) {
- (_, None) => true,
- (None, Some(_)) => false,
- (Some(provided), Some(required)) => {
- if required.len() > provided.len() {
- false
- } else {
- let fast_match = required
- .iter()
- .zip(provided.iter())
- .all(|(order1, order2)| order1.eq(order2));
-
- if !fast_match {
- let eq_properties = equal_properties();
- let eq_classes = eq_properties.classes();
- if !eq_classes.is_empty() {
- let normalized_required_exprs = required
- .iter()
- .map(|e| {
- normalize_sort_expr_with_equivalence_properties(
- e.clone(),
- eq_classes,
- )
- })
- .collect::<Vec<_>>();
- let normalized_provided_exprs = provided
- .iter()
- .map(|e| {
- normalize_sort_expr_with_equivalence_properties(
- e.clone(),
- eq_classes,
- )
- })
- .collect::<Vec<_>>();
- normalized_required_exprs
- .iter()
- .zip(normalized_provided_exprs.iter())
- .all(|(order1, order2)| order1.eq(order2))
- } else {
- fast_match
- }
- } else {
- fast_match
- }
- }
- }
- }
-}
-
#[derive(Debug, Clone)]
struct JoinKeyPairs {
left_keys: Vec<Arc<dyn PhysicalExpr>>,
@@ -1063,10 +1005,10 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::Operator;
- use datafusion_physical_expr::expressions::binary;
- use datafusion_physical_expr::expressions::lit;
- use datafusion_physical_expr::expressions::Column;
- use datafusion_physical_expr::{expressions, PhysicalExpr};
+ use datafusion_physical_expr::{
+ expressions, expressions::binary, expressions::lit, expressions::Column,
+ PhysicalExpr, PhysicalSortExpr,
+ };
use std::ops::Deref;
use super::*;
diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs
index 36b00a0e0..0fd0600fb 100644
--- a/datafusion/core/src/physical_optimizer/mod.rs
+++ b/datafusion/core/src/physical_optimizer/mod.rs
@@ -22,6 +22,7 @@ pub mod aggregate_statistics;
pub mod coalesce_batches;
pub mod enforcement;
pub mod join_selection;
+pub mod optimize_sorts;
pub mod optimizer;
pub mod pruning;
pub mod repartition;
diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs
new file mode 100644
index 000000000..cb421b7b8
--- /dev/null
+++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs
@@ -0,0 +1,887 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! OptimizeSorts optimizer rule inspects [SortExec]s in the given physical
+//! plan and removes the ones it can prove unnecessary. The rule can work on
+//! valid *and* invalid physical plans with respect to sorting requirements,
+//! but always produces a valid physical plan in this sense.
+//!
+//! An unrealistic but easy-to-follow example: assume that we somehow get the fragment
+//! "SortExec: [nullable_col@0 ASC]",
+//! " SortExec: [non_nullable_col@1 ASC]",
+//! in the physical plan. The inner sort (on `non_nullable_col`), which runs first, is
+//! unnecessary since its output ordering is overwritten by the outer SortExec.
+//! Therefore, this rule removes the inner sort from the physical plan.
+use crate::error::Result;
+use crate::physical_optimizer::utils::{
+ add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete,
+};
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{reverse_sort_options, DataFusionError};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::izip;
+use std::iter::zip;
+use std::sync::Arc;
+
+/// Physical optimizer rule that walks the plan bottom-up and drops every
+/// `SortExec` it can prove unnecessary.
+#[derive(Default)]
+pub struct OptimizeSorts {}
+
+impl OptimizeSorts {
+    /// Creates the rule; equivalent to `OptimizeSorts::default()`.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+/// Bookkeeping node used by the [OptimizeSorts] rule: wraps a plan together
+/// with, for every child of that plan, the chain of operators leading up from
+/// the closest `SortExec` descendant.
+#[derive(Debug, Clone)]
+struct PlanWithCorrespondingSort {
+    plan: Arc<dyn ExecutionPlan>,
+    // One entry per child of `plan`. Each entry lists the `ExecutionPlan`s from
+    // the closest `SortExec` up to the current plan, each paired with a child
+    // index that is needed when the chain is rebuilt after removing the sort.
+    sort_onwards: Vec<Vec<(usize, Arc<dyn ExecutionPlan>)>>,
+}
+
+impl PlanWithCorrespondingSort {
+    /// Wraps `plan` with an empty sort trace for each of its children.
+    pub fn new(plan: Arc<dyn ExecutionPlan>) -> Self {
+        let child_count = plan.children().len();
+        Self {
+            sort_onwards: vec![Vec::new(); child_count],
+            plan,
+        }
+    }
+
+    /// Wraps every child of `self.plan` in a fresh, trace-less node.
+    pub fn children(&self) -> Vec<PlanWithCorrespondingSort> {
+        self.plan
+            .children()
+            .into_iter()
+            .map(PlanWithCorrespondingSort::new)
+            .collect()
+    }
+}
+
+impl TreeNodeRewritable for PlanWithCorrespondingSort {
+    /// Applies `transform` to every child, rebuilds the plan on top of the
+    /// transformed children and inherits, per child, that child's first sort
+    /// trace.
+    fn map_children<F>(self, transform: F) -> Result<Self>
+    where
+        F: FnMut(Self) -> Result<Self>,
+    {
+        let children = self.children();
+        // Leaf plans have nothing to rewrite.
+        if children.is_empty() {
+            return Ok(self);
+        }
+        let rewritten = children
+            .into_iter()
+            .map(transform)
+            .collect::<Result<Vec<_>>>()?;
+        let mut child_plans = Vec::with_capacity(rewritten.len());
+        let mut sort_onwards = Vec::with_capacity(rewritten.len());
+        for child in &rewritten {
+            child_plans.push(child.plan.clone());
+            // TODO: When `maintains_input_order` returns Vec<bool>,
+            // pass the order-enforcing sort upwards.
+            sort_onwards.push(child.sort_onwards.first().cloned().unwrap_or_default());
+        }
+        let plan = with_new_children_if_necessary(self.plan, child_plans)?;
+        Ok(PlanWithCorrespondingSort { plan, sort_onwards })
+    }
+}
+
+impl PhysicalOptimizerRule for OptimizeSorts {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &SessionConfig,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        // Visit the plan in post-order (children before parents) so each node
+        // sees the sort traces already collected from below it.
+        let rewritten =
+            PlanWithCorrespondingSort::new(plan).transform_up(&optimize_sorts)?;
+        Ok(rewritten.plan)
+    }
+
+    fn name(&self) -> &str {
+        "OptimizeSorts"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+/// Per-node callback of [OptimizeSorts], applied bottom-up via `transform_up`.
+///
+/// First, if the node itself is a `SortExec` whose input already satisfies its
+/// ordering, the sort is replaced by its input. Otherwise, for every child the
+/// ordering the plan requires (`required_input_ordering`) is compared with the
+/// child's `output_ordering`:
+/// - required and present but not satisfied: the traced sort below the child
+///   (if any) is removed and a new sort on the required ordering is added
+///   above the child;
+/// - required and satisfied with a traced `SortExec`: the sort is removed if
+///   its input already satisfies its own output ordering; otherwise, when the
+///   plan is a `WindowAggExec`, `analyze_window_sort_removal` may return a
+///   rebuilt window plan, which is returned immediately;
+/// - required but the child has no ordering: a sort is added above the child;
+/// - not required but the child is ordered: if the plan does not maintain its
+///   input order, the traced sort is removed.
+/// Finally, each child's trace is extended with the new plan when it maintains
+/// input order and has no requirement on that child; otherwise the trace is
+/// reset (and restarted at the new plan if it is itself a `SortExec`).
+///
+/// Always returns `Ok(Some(_))` on success; errors propagate from the helper
+/// calls (plan reconstruction, sort insertion, downcasting).
+fn optimize_sorts(
+    requirements: PlanWithCorrespondingSort,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+    // Perform naive analysis at the beginning -- remove already-satisfied sorts:
+    if let Some(result) = analyze_immediate_sort_removal(&requirements)? {
+        return Ok(Some(result));
+    }
+    let plan = &requirements.plan;
+    let mut new_children = plan.children().clone();
+    let mut new_onwards = requirements.sort_onwards.clone();
+    for (idx, (child, sort_onwards, required_ordering)) in izip!(
+        new_children.iter_mut(),
+        new_onwards.iter_mut(),
+        plan.required_input_ordering()
+    )
+    .enumerate()
+    {
+        let physical_ordering = child.output_ordering();
+        match (required_ordering, physical_ordering) {
+            (Some(required_ordering), Some(physical_ordering)) => {
+                let is_ordering_satisfied = ordering_satisfy_concrete(
+                    physical_ordering,
+                    required_ordering,
+                    || child.equivalence_properties(),
+                );
+                if !is_ordering_satisfied {
+                    // Make sure we preserve the ordering requirements:
+                    update_child_to_remove_unnecessary_sort(child, sort_onwards)?;
+                    let sort_expr = required_ordering.to_vec();
+                    *child = add_sort_above_child(child, sort_expr)?;
+                    // NOTE(review): the trace pass at the end of this function clears
+                    // traces of children that have a requirement, so this entry appears
+                    // not to survive -- confirm intent.
+                    sort_onwards.push((idx, child.clone()))
+                } else if let [first, ..] = sort_onwards.as_slice() {
+                    // The ordering requirement is met, we can analyze if there is an unnecessary sort:
+                    let sort_any = first.1.clone();
+                    let sort_exec = convert_to_sort_exec(&sort_any)?;
+                    let sort_output_ordering = sort_exec.output_ordering();
+                    let sort_input_ordering = sort_exec.input().output_ordering();
+                    // Simple analysis: Does the input of the sort in question already satisfy the ordering requirements?
+                    if ordering_satisfy(sort_input_ordering, sort_output_ordering, || {
+                        sort_exec.input().equivalence_properties()
+                    }) {
+                        update_child_to_remove_unnecessary_sort(child, sort_onwards)?;
+                    } else if let Some(window_agg_exec) =
+                        requirements.plan.as_any().downcast_ref::<WindowAggExec>()
+                    {
+                        // For window expressions, we can remove some sorts when we can
+                        // calculate the result in reverse:
+                        if let Some(res) = analyze_window_sort_removal(
+                            window_agg_exec,
+                            sort_exec,
+                            sort_onwards,
+                        )? {
+                            return Ok(Some(res));
+                        }
+                    }
+                    // TODO: Once we can ensure that required ordering information propagates with
+                    // necessary lineage information, compare `sort_input_ordering` and `required_ordering`.
+                    // This will enable us to handle cases such as (a,b) -> Sort -> (a,b,c) -> Required(a,b).
+                    // Currently, we can not remove such sorts.
+                }
+            }
+            (Some(required), None) => {
+                // Ordering requirement is not met, we should add a SortExec to the plan.
+                let sort_expr = required.to_vec();
+                *child = add_sort_above_child(child, sort_expr)?;
+                *sort_onwards = vec![(idx, child.clone())];
+            }
+            (None, Some(_)) => {
+                // The child is ordered, but this plan imposes no requirement on it.
+                // If the plan does not maintain its input order, any ordering produced
+                // by a traced sort below is lost anyway, so remove that sort:
+                if !requirements.plan.maintains_input_order() {
+                    update_child_to_remove_unnecessary_sort(child, sort_onwards)?;
+                }
+            }
+            (None, None) => {}
+        }
+    }
+    // Leaf plans have no children to rebuild and no traces to update.
+    if plan.children().is_empty() {
+        Ok(Some(requirements))
+    } else {
+        let new_plan = requirements.plan.with_new_children(new_children)?;
+        // Extend or reset each child's sort trace relative to the rebuilt plan.
+        for (idx, (trace, required_ordering)) in new_onwards
+            .iter_mut()
+            .zip(new_plan.required_input_ordering())
+            .enumerate()
+            .take(new_plan.children().len())
+        {
+            // TODO: When `maintains_input_order` returns a `Vec<bool>`, use corresponding index.
+            if new_plan.maintains_input_order()
+                && required_ordering.is_none()
+                && !trace.is_empty()
+            {
+                trace.push((idx, new_plan.clone()));
+            } else {
+                trace.clear();
+                if new_plan.as_any().is::<SortExec>() {
+                    trace.push((idx, new_plan.clone()));
+                }
+            }
+        }
+        Ok(Some(PlanWithCorrespondingSort {
+            plan: new_plan,
+            sort_onwards: new_onwards,
+        }))
+    }
+}
+
+/// Checks whether `requirements.plan` is a `SortExec` whose input is already
+/// ordered at least as finely as the sort would make it. If so, returns the
+/// sort's input in its place, with the last entry of the sort trace dropped;
+/// otherwise returns `Ok(None)`.
+fn analyze_immediate_sort_removal(
+    requirements: &PlanWithCorrespondingSort,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+    let sort_exec = match requirements.plan.as_any().downcast_ref::<SortExec>() {
+        Some(sort_exec) => sort_exec,
+        None => return Ok(None),
+    };
+    let input = sort_exec.input();
+    let already_sorted = ordering_satisfy(
+        input.output_ordering(),
+        sort_exec.output_ordering(),
+        || input.equivalence_properties(),
+    );
+    if !already_sorted {
+        return Ok(None);
+    }
+    // A `SortExec` has exactly one child, so index zero is always valid.
+    let mut trimmed_trace = requirements.sort_onwards[0].clone();
+    // `Vec::pop` is a no-op on an empty vector.
+    trimmed_trace.pop();
+    Ok(Some(PlanWithCorrespondingSort {
+        plan: input.clone(),
+        sort_onwards: vec![trimmed_trace],
+    }))
+}
+
+/// Tries to drop `sort_exec`, the traced sort feeding `window_agg_exec`.
+/// When `can_skip_sort` reports that the window can be computed on the
+/// ordering of the sort's input (possibly by reversing every window
+/// expression), returns a rebuilt `WindowAggExec` on top of the plan with the
+/// sort removed. Returns `Ok(None)` when the sort has to stay.
+fn analyze_window_sort_removal(
+    window_agg_exec: &WindowAggExec,
+    sort_exec: &SortExec,
+    sort_onward: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<Option<PlanWithCorrespondingSort>> {
+    let required_ordering = sort_exec.output_ordering().ok_or_else(|| {
+        DataFusionError::Plan("A SortExec should have output ordering".to_string())
+    })?;
+    // Without any ordering on the sort's input there is nothing to reuse.
+    let physical_ordering = match sort_exec.input().output_ordering() {
+        Some(ordering) => ordering,
+        None => return Ok(None),
+    };
+    let window_expr = window_agg_exec.window_expr();
+    let (can_skip_sorting, should_reverse) = can_skip_sort(
+        window_expr[0].partition_by(),
+        required_ordering,
+        &sort_exec.input().schema(),
+        physical_ordering,
+    )?;
+    if !can_skip_sorting {
+        return Ok(None);
+    }
+    // When reversing, every expression must support it; otherwise keep the sort.
+    let new_window_expr = if should_reverse {
+        match window_expr
+            .iter()
+            .map(|e| e.get_reverse_expr())
+            .collect::<Option<Vec<_>>>()
+        {
+            Some(reversed) => reversed,
+            None => return Ok(None),
+        }
+    } else {
+        window_expr.to_vec()
+    };
+    let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?;
+    let new_schema = new_child.schema();
+    let new_plan = Arc::new(WindowAggExec::try_new(
+        new_window_expr,
+        new_child,
+        new_schema,
+        window_agg_exec.partition_keys.clone(),
+        Some(physical_ordering.to_vec()),
+    )?);
+    Ok(Some(PlanWithCorrespondingSort::new(new_plan)))
+}
+
+/// Replaces `child` with the plan rebuilt from `sort_onwards` minus its traced
+/// sort. Does nothing when `sort_onwards` is empty (no sort is traced).
+fn update_child_to_remove_unnecessary_sort(
+    child: &mut Arc<dyn ExecutionPlan>,
+    sort_onwards: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<()> {
+    if sort_onwards.is_empty() {
+        return Ok(());
+    }
+    *child = remove_corresponding_sort_from_sub_plan(sort_onwards)?;
+    Ok(())
+}
+
+/// Converts an [ExecutionPlan] trait object to a [SortExec] when possible.
+fn convert_to_sort_exec(sort_any: &Arc<dyn ExecutionPlan>) -> Result<&SortExec> {
+ sort_any.as_any().downcast_ref::<SortExec>().ok_or_else(|| {
+ DataFusionError::Plan("Given ExecutionPlan is not a SortExec".to_string())
+ })
+}
+
+/// Rebuilds the plan fragment recorded in `sort_onwards` with its leading
+/// `SortExec` spliced out, returns the new top of the fragment, and clears
+/// `sort_onwards`.
+///
+/// `sort_onwards[0]` must be a `SortExec` (an error is returned otherwise);
+/// indexing panics if `sort_onwards` is empty.
+fn remove_corresponding_sort_from_sub_plan(
+    sort_onwards: &mut Vec<(usize, Arc<dyn ExecutionPlan>)>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let (sort_child_idx, sort_any) = sort_onwards[0].clone();
+    let sort_exec = convert_to_sort_exec(&sort_any)?;
+    // The sort's input takes the place of the sort itself.
+    let mut rebuilt = sort_exec.input().clone();
+    // NOTE(review): each layer's child slot is taken from the index stored with
+    // the *previous* entry, while `optimize_sorts` pushes layers with their own
+    // child index; this only lines up when those indices are 0 -- confirm.
+    let mut slot = sort_child_idx;
+    // Skip entry 0 (the sort being removed) and rebuild each layer above it.
+    for (next_slot, layer) in &sort_onwards[1..] {
+        let mut children = layer.children();
+        children[slot] = rebuilt;
+        rebuilt = layer.clone().with_new_children(children)?;
+        slot = *next_slot;
+    }
+    // The sort is gone, so the trace no longer points at anything.
+    sort_onwards.clear();
+    Ok(rebuilt)
+}
+
+/// Per-column comparison result that `can_skip_sort` uses to decide whether
+/// a sort can be removed.
+#[derive(Debug)]
+pub struct ColumnInfo {
+    /// Whether the existing physical ordering of this column can serve the
+    /// required ordering (as-is or traversed in reverse).
+    is_aligned: bool,
+    /// Whether the existing physical ordering is the reverse of the required one.
+    reverse: bool,
+    /// Whether this column is one of the supplied partition keys.
+    is_partition: bool,
+}
+
+/// Decides whether a sort producing `required` can be dropped because its
+/// input already has `physical_ordering`, given the window's `partition_keys`.
+/// Returns `(can_skip, should_reverse)`:
+/// - `can_skip`: every compared column is aligned, all partition-key columns
+///   agree on direction, and all remaining (order-by) columns agree on
+///   direction;
+/// - `should_reverse`: whether the order-by columns are reversed, i.e. the
+///   window expressions must be evaluated in reverse to reuse the ordering.
+///
+/// A `required` ordering longer than `physical_ordering` is never skippable.
+pub fn can_skip_sort(
+    partition_keys: &[Arc<dyn PhysicalExpr>],
+    required: &[PhysicalSortExpr],
+    input_schema: &SchemaRef,
+    physical_ordering: &[PhysicalSortExpr],
+) -> Result<(bool, bool)> {
+    // Returns (all aligned with a common direction, that direction); an empty
+    // group is trivially skippable and not reversed.
+    fn uniform_direction(group: &[&ColumnInfo]) -> (bool, bool) {
+        match group.first() {
+            None => (true, false),
+            Some(head) => {
+                let consistent = group
+                    .iter()
+                    .all(|info| info.is_aligned && info.reverse == head.reverse);
+                (consistent, head.reverse)
+            }
+        }
+    }
+
+    if physical_ordering.len() < required.len() {
+        return Ok((false, false));
+    }
+    let col_infos = zip(required, physical_ordering)
+        .map(|(sort_expr, physical_expr)| {
+            let is_partition = partition_keys.iter().any(|e| e.eq(&sort_expr.expr));
+            let (is_aligned, reverse) =
+                check_alignment(input_schema, physical_expr, sort_expr);
+            ColumnInfo {
+                is_aligned,
+                reverse,
+                is_partition,
+            }
+        })
+        .collect::<Vec<_>>();
+    let (partition_cols, order_by_cols): (Vec<&ColumnInfo>, Vec<&ColumnInfo>) =
+        col_infos.iter().partition(|info| info.is_partition);
+    let (can_skip_partition_bys, _) = uniform_direction(&partition_cols);
+    let (can_skip_order_bys, should_reverse_order_bys) =
+        uniform_direction(&order_by_cols);
+    Ok((
+        can_skip_order_bys && can_skip_partition_bys,
+        should_reverse_order_bys,
+    ))
+}
+
+/// Compares `physical_ordering` and the `required` ordering of one column and
+/// returns a tuple `(is_aligned, is_reversed)`:
+/// 1. `is_aligned`: whether the existing physical order can serve the
+///    requirement, either as-is or when traversed in reverse, and
+/// 2. `is_reversed`: whether the window expression would have to be reversed
+///    to use the existing order.
+///
+/// Columns with different expressions are never aligned. If nullability cannot
+/// be determined against `input_schema`, the column is conservatively treated
+/// as nullable (so NULLS FIRST/LAST must match too) instead of panicking.
+fn check_alignment(
+    input_schema: &SchemaRef,
+    physical_ordering: &PhysicalSortExpr,
+    required: &PhysicalSortExpr,
+) -> (bool, bool) {
+    if !required.expr.eq(&physical_ordering.expr) {
+        return (false, false);
+    }
+    // Treating unknown nullability as nullable is the stricter choice: only
+    // exact or exactly-reversed orderings are accepted then.
+    let nullable = required.expr.nullable(input_schema).unwrap_or(true);
+    let physical_opts = physical_ordering.options;
+    let required_opts = required.options;
+    let is_reversed = if nullable {
+        physical_opts == reverse_sort_options(required_opts)
+    } else {
+        // If the column is not nullable, NULLS FIRST/LAST is not important.
+        physical_opts.descending != required_opts.descending
+    };
+    let can_skip = !nullable || is_reversed || (physical_opts == required_opts);
+    (can_skip, is_reversed)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::physical_plan::displayable;
+ use crate::physical_plan::filter::FilterExec;
+ use crate::physical_plan::memory::MemoryExec;
+ use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
+ use crate::physical_plan::windows::create_window_expr;
+ use crate::prelude::SessionContext;
+ use arrow::compute::SortOptions;
+ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use datafusion_common::Result;
+ use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction};
+ use datafusion_physical_expr::expressions::{col, NotExpr};
+ use datafusion_physical_expr::PhysicalSortExpr;
+ use std::sync::Arc;
+
+    fn create_test_schema() -> Result<SchemaRef> {
+        // One nullable and one non-nullable Int32 column.
+        let fields = vec![
+            Field::new("nullable_col", DataType::Int32, true),
+            Field::new("non_nullable_col", DataType::Int32, false),
+        ];
+        Ok(Arc::new(Schema::new(fields)))
+    }
+
+    #[tokio::test]
+    async fn test_is_column_aligned_nullable() -> Result<()> {
+        let schema = create_test_schema()?;
+        // (physical (desc, nulls_first), required (desc, nulls_first), expected (aligned, reverse))
+        let cases = [
+            ((true, true), (false, false), (true, true)),
+            ((true, true), (false, true), (false, false)),
+            ((true, true), (true, false), (false, false)),
+            ((true, false), (false, true), (true, true)),
+            ((true, false), (false, false), (false, false)),
+            ((true, false), (true, true), (false, false)),
+        ];
+        let sort_expr = |descending: bool, nulls_first: bool| -> Result<PhysicalSortExpr> {
+            Ok(PhysicalSortExpr {
+                expr: col("nullable_col", &schema)?,
+                options: SortOptions {
+                    descending,
+                    nulls_first,
+                },
+            })
+        };
+        for ((phys_desc, phys_nulls_first), (req_desc, req_nulls_first), expected) in cases {
+            let physical = sort_expr(phys_desc, phys_nulls_first)?;
+            let required = sort_expr(req_desc, req_nulls_first)?;
+            assert_eq!(check_alignment(&schema, &physical, &required), expected);
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_is_column_aligned_non_nullable() -> Result<()> {
+        let schema = create_test_schema()?;
+        // For a non-nullable column only the direction matters.
+        // (physical (desc, nulls_first), required (desc, nulls_first), expected (aligned, reverse))
+        let cases = [
+            ((true, true), (false, false), (true, true)),
+            ((true, true), (false, true), (true, true)),
+            ((true, true), (true, false), (true, false)),
+            ((true, false), (false, true), (true, true)),
+            ((true, false), (false, false), (true, true)),
+            ((true, false), (true, true), (true, false)),
+        ];
+        let sort_expr = |descending: bool, nulls_first: bool| -> Result<PhysicalSortExpr> {
+            Ok(PhysicalSortExpr {
+                expr: col("non_nullable_col", &schema)?,
+                options: SortOptions {
+                    descending,
+                    nulls_first,
+                },
+            })
+        };
+        for ((phys_desc, phys_nulls_first), (req_desc, req_nulls_first), expected) in cases {
+            let physical = sort_expr(phys_desc, phys_nulls_first)?;
+            let required = sort_expr(req_desc, req_nulls_first)?;
+            assert_eq!(check_alignment(&schema, &physical, &required), expected);
+        }
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_remove_unnecessary_sort() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let conf = session_ctx.copied_config();
+        let schema = create_test_schema()?;
+        // Ascending sort (default options) on the named column.
+        let sort_on = |name: &str| {
+            vec![PhysicalSortExpr {
+                expr: col(name, schema.as_ref()).unwrap(),
+                options: SortOptions::default(),
+            }]
+        };
+        // Renders a plan as indented text, one entry per line.
+        let render = |plan: &Arc<dyn ExecutionPlan>| -> Vec<String> {
+            displayable(plan.as_ref())
+                .indent()
+                .to_string()
+                .trim()
+                .lines()
+                .map(str::to_owned)
+                .collect()
+        };
+
+        let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+            as Arc<dyn ExecutionPlan>;
+        let inner_sort =
+            Arc::new(SortExec::try_new(sort_on("non_nullable_col"), source, None)?)
+                as Arc<dyn ExecutionPlan>;
+        let physical_plan =
+            Arc::new(SortExec::try_new(sort_on("nullable_col"), inner_sort, None)?)
+                as Arc<dyn ExecutionPlan>;
+
+        // The last rendered line is the `MemoryExec` leaf, which is not compared.
+        let expected = vec![
+            "SortExec: [nullable_col@0 ASC]",
+            " SortExec: [non_nullable_col@1 ASC]",
+        ];
+        let actual = render(&physical_plan);
+        let actual_trim_last = &actual[..actual.len() - 1];
+        assert_eq!(
+            expected, actual_trim_last,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        let optimized_physical_plan =
+            OptimizeSorts::new().optimize(physical_plan, &conf)?;
+        let expected = vec!["SortExec: [nullable_col@0 ASC]"];
+        let actual = render(&optimized_physical_plan);
+        let actual_trim_last = &actual[..actual.len() - 1];
+        assert_eq!(
+            expected, actual_trim_last,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+        Ok(())
+    }
+
+ /// Builds MemoryExec -> SortExec (DESC) -> WindowAggExec -> SortExec (ASC NULLS LAST)
+ /// -> FilterExec -> WindowAggExec and checks that `OptimizeSorts` drops the middle
+ /// `SortExec`. The top window expression is reversed instead: its frame changes
+ /// from `Preceding(NULL)..CurrentRow` to `CurrentRow..Following(NULL)`.
+ #[tokio::test]
+ async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("non_nullable_col", source.schema().as_ref()).unwrap(),
+ options: SortOptions {
+ descending: true,
+ nulls_first: true,
+ },
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?)
+ as Arc<dyn ExecutionPlan>;
+ let window_agg_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateFunction(AggregateFunction::Count),
+ "count".to_owned(),
+ &[col("non_nullable_col", &schema)?],
+ &[],
+ &sort_exprs,
+ Arc::new(WindowFrame::new(true)),
+ schema.as_ref(),
+ )?],
+ sort_exec.clone(),
+ sort_exec.schema(),
+ vec![],
+ Some(sort_exprs),
+ )?) as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("non_nullable_col", window_agg_exec.schema().as_ref()).unwrap(),
+ options: SortOptions {
+ descending: false,
+ nulls_first: false,
+ },
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(
+ sort_exprs.clone(),
+ window_agg_exec,
+ None,
+ )?) as Arc<dyn ExecutionPlan>;
+ // Put a FilterExec between the sort and the top window, so the test checks that
+ // the sort can be removed even when it is not the window's direct child.
+ let filter_exec = Arc::new(FilterExec::try_new(
+ Arc::new(NotExpr::new(
+ col("non_nullable_col", schema.as_ref()).unwrap(),
+ )),
+ sort_exec,
+ )?) as Arc<dyn ExecutionPlan>;
+ let window_agg_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateFunction(AggregateFunction::Count),
+ "count".to_owned(),
+ &[col("non_nullable_col", &schema)?],
+ &[],
+ &sort_exprs,
+ Arc::new(WindowFrame::new(true)),
+ schema.as_ref(),
+ )?],
+ filter_exec.clone(),
+ filter_exec.schema(),
+ vec![],
+ Some(sort_exprs),
+ )?) as Arc<dyn ExecutionPlan>;
+ let physical_plan = window_agg_exec;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]",
+ " FilterExec: NOT non_nullable_col@1",
+ " SortExec: [non_nullable_col@2 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]",
+ " SortExec: [non_nullable_col@1 DESC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ OptimizeSorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ // The middle SortExec is gone and the top window now uses the reversed frame.
+ let expected = {
+ vec![
+ "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]",
+ " FilterExec: NOT non_nullable_col@1",
+ " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]",
+ " SortExec: [non_nullable_col@1 DESC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+
+ /// A `SortPreservingMergeExec` placed directly on an unsorted `MemoryExec` has an
+ /// unmet ordering requirement. `OptimizeSorts` must insert a `SortExec` on
+ /// `nullable_col` below the merge.
+ #[tokio::test]
+ async fn test_add_required_sort() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ }];
+ let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, source))
+ as Arc<dyn ExecutionPlan>;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ // Compare every line except the last one (the leaf source line).
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ OptimizeSorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ let expected = {
+ vec![
+ "SortPreservingMergeExec: [nullable_col@0 ASC]",
+ " SortExec: [nullable_col@0 ASC]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ // Again ignore the trailing leaf line.
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+
+ /// Two stacked `SortPreservingMergeExec` / `SortExec` pairs on the same key
+ /// (`nullable_col ASC`). The outer `SortExec` is redundant, because its input is
+ /// already ordered by the inner merge, and `OptimizeSorts` must remove it.
+ #[tokio::test]
+ async fn test_remove_unnecessary_sort1() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_preserving_merge_exec =
+ Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
+ as Arc<dyn ExecutionPlan>;
+ // Second layer: the same ordering requested again on top of the first merge.
+ let sort_exprs = vec![PhysicalSortExpr {
+ expr: col("nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ }];
+ let sort_exec = Arc::new(SortExec::try_new(
+ sort_exprs.clone(),
+ sort_preserving_merge_exec,
+ None,
+ )?) as Arc<dyn ExecutionPlan>;
+ let sort_preserving_merge_exec =
+ Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
+ as Arc<dyn ExecutionPlan>;
+ let physical_plan = sort_preserving_merge_exec;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "SortPreservingMergeExec: [nullable_col@0 ASC]",
+ " SortExec: [nullable_col@0 ASC]",
+ " SortPreservingMergeExec: [nullable_col@0 ASC]",
+ " SortExec: [nullable_col@0 ASC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ OptimizeSorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ // Only the innermost SortExec remains.
+ let expected = {
+ vec![
+ "SortPreservingMergeExec: [nullable_col@0 ASC]",
+ " SortPreservingMergeExec: [nullable_col@0 ASC]",
+ " SortExec: [nullable_col@0 ASC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+
+ /// A `SortExec` on `nullable_col` alone sits below a `SortPreservingMergeExec`
+ /// that requires `nullable_col, non_nullable_col`. `OptimizeSorts` must change
+ /// the `SortExec` so that it sorts on both columns.
+ #[tokio::test]
+ async fn test_change_wrong_sorting() -> Result<()> {
+ let session_ctx = SessionContext::new();
+ let conf = session_ctx.copied_config();
+ let schema = create_test_schema()?;
+ let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?)
+ as Arc<dyn ExecutionPlan>;
+ let sort_exprs = vec![
+ PhysicalSortExpr {
+ expr: col("nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ },
+ PhysicalSortExpr {
+ expr: col("non_nullable_col", schema.as_ref()).unwrap(),
+ options: SortOptions::default(),
+ },
+ ];
+ // Deliberately sort on the first key only, which is too weak for the merge above.
+ let sort_exec = Arc::new(SortExec::try_new(
+ vec![sort_exprs[0].clone()],
+ source,
+ None,
+ )?) as Arc<dyn ExecutionPlan>;
+ let sort_preserving_merge_exec =
+ Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec))
+ as Arc<dyn ExecutionPlan>;
+ let physical_plan = sort_preserving_merge_exec;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ let expected = {
+ vec![
+ "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]",
+ " SortExec: [nullable_col@0 ASC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ let optimized_physical_plan =
+ OptimizeSorts::new().optimize(physical_plan, &conf)?;
+ let formatted = displayable(optimized_physical_plan.as_ref())
+ .indent()
+ .to_string();
+ let expected = {
+ vec![
+ "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]",
+ " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]",
+ " MemoryExec: partitions=0, partition_sizes=[]",
+ ]
+ };
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+ }
+}
diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs
index 4aceb776d..8f1fe2d08 100644
--- a/datafusion/core/src/physical_optimizer/utils.rs
+++ b/datafusion/core/src/physical_optimizer/utils.rs
@@ -21,7 +21,12 @@ use super::optimizer::PhysicalOptimizerRule;
use crate::execution::context::SessionConfig;
use crate::error::Result;
+use crate::physical_plan::sorts::sort::SortExec;
use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan};
+use datafusion_physical_expr::{
+ normalize_sort_expr_with_equivalence_properties, EquivalenceProperties,
+ PhysicalSortExpr,
+};
use std::sync::Arc;
/// Convenience rule for writing optimizers: recursively invoke
@@ -45,3 +50,73 @@ pub fn optimize_children(
with_new_children_if_necessary(plan, children)
}
}
+
+/// Returns `true` when the `provided` ordering (if any) satisfies the
+/// `required` ordering (if any).
+///
+/// A missing requirement is always satisfied. A present requirement is never
+/// satisfied by a missing ordering. When both are present, the decision is
+/// delegated to [`ordering_satisfy_concrete`], which receives
+/// `equal_properties` unevaluated.
+pub fn ordering_satisfy<F: FnOnce() -> EquivalenceProperties>(
+    provided: Option<&[PhysicalSortExpr]>,
+    required: Option<&[PhysicalSortExpr]>,
+    equal_properties: F,
+) -> bool {
+    match required {
+        None => true,
+        Some(required) => provided.map_or(false, |provided| {
+            ordering_satisfy_concrete(provided, required, equal_properties)
+        }),
+    }
+}
+
+/// Returns `true` when `required` is a prefix of `provided`, i.e. an input
+/// ordered by `provided` is also ordered by `required`.
+///
+/// First the expressions are compared directly. If that fails,
+/// `equal_properties` is evaluated. When it contains equivalence classes, both
+/// orderings are normalized with those classes and compared again.
+/// `equal_properties` is never called when the direct comparison succeeds or
+/// when `required` is longer than `provided`.
+pub fn ordering_satisfy_concrete<F: FnOnce() -> EquivalenceProperties>(
+    provided: &[PhysicalSortExpr],
+    required: &[PhysicalSortExpr],
+    equal_properties: F,
+) -> bool {
+    if provided.len() < required.len() {
+        return false;
+    }
+    let is_direct_prefix = required
+        .iter()
+        .zip(provided)
+        .all(|(req, prov)| req == prov);
+    if is_direct_prefix {
+        return true;
+    }
+    let equivalence = equal_properties();
+    let eq_classes = equivalence.classes();
+    if eq_classes.is_empty() {
+        return false;
+    }
+    let normalize = |expr: &PhysicalSortExpr| {
+        normalize_sort_expr_with_equivalence_properties(expr.clone(), eq_classes)
+    };
+    required
+        .iter()
+        .map(&normalize)
+        .zip(provided.iter().map(&normalize))
+        .all(|(req, prov)| req == prov)
+}
+
+/// Places a [`SortExec`] ordered by `sort_expr` on top of `child`.
+///
+/// If `child` has more than one output partition, the sort is built with
+/// `SortExec::new_with_partitioning(.., true, None)`, keeping the child's
+/// partitioning. Otherwise a plain `SortExec::try_new` is used.
+pub fn add_sort_above_child(
+    child: &Arc<dyn ExecutionPlan>,
+    sort_expr: Vec<PhysicalSortExpr>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let input = Arc::clone(child);
+    let has_many_partitions = child.output_partitioning().partition_count() > 1;
+    let sort: Arc<dyn ExecutionPlan> = if has_many_partitions {
+        Arc::new(SortExec::new_with_partitioning(sort_expr, input, true, None))
+    } else {
+        Arc::new(SortExec::try_new(sort_expr, input, None)?)
+    };
+    Ok(sort)
+}
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index b29dc0cb8..1c36014f2 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -266,6 +266,22 @@ impl<T> Drop for AbortOnDropMany<T> {
}
}
+/// Transposes the given vector of vectors: item `j` of row `i` becomes item
+/// `i` of row `j`.
+///
+/// The number of output rows is the length of the first input row. Items that
+/// lie beyond that length in later rows are dropped, and shorter rows leave the
+/// trailing output rows shorter. An empty input yields an empty output.
+pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
+    let width = match original.first() {
+        Some(first_row) => first_row.len(),
+        None => return Vec::new(),
+    };
+    let mut columns: Vec<Vec<T>> =
+        std::iter::repeat_with(Vec::new).take(width).collect();
+    original.into_iter().for_each(|row| {
+        columns
+            .iter_mut()
+            .zip(row)
+            .for_each(|(column, item)| column.push(item));
+    });
+    columns
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -332,6 +348,15 @@ mod tests {
assert_eq!(actual, expected);
Ok(())
}
+
+    /// Checks `transpose` on a rectangular input and on an empty input.
+    #[test]
+    fn test_transpose() -> Result<()> {
+        let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]];
+        let transposed = transpose(in_data);
+        let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]];
+        assert_eq!(expected, transposed);
+
+        // An empty input has no rows, so there are no columns to produce.
+        let empty: Vec<Vec<i32>> = vec![];
+        assert!(transpose(empty).is_empty());
+        Ok(())
+    }
}
/// Write in Arrow IPC format.
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 0a598be87..e16c518b6 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -583,7 +583,7 @@ impl DefaultPhysicalPlanner {
let physical_input_schema = input_exec.schema();
let sort_keys = sort_keys
.iter()
- .map(|e| match e {
+ .map(|(e, _)| match e {
Expr::Sort(expr::Sort {
expr,
asc,
diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs
index 9492fb749..3dc0c6d33 100644
--- a/datafusion/core/src/physical_plan/repartition.rs
+++ b/datafusion/core/src/physical_plan/repartition.rs
@@ -289,7 +289,16 @@ impl ExecutionPlan for RepartitionExec {
}
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
- None
+ // Forward the input's ordering only when `maintains_input_order` says the
+ // repartitioning keeps it; otherwise report no ordering.
+ if self.maintains_input_order() {
+ self.input().output_ordering()
+ } else {
+ None
+ }
+ }
+
+ fn maintains_input_order(&self) -> bool {
+ // Ordering is treated as preserved only when there is at most one input
+ // partition. NOTE(review): presumably because rows coming from several input
+ // partitions can interleave within one output partition — confirm.
+ self.input().output_partitioning().partition_count() <= 1
}
fn equivalence_properties(&self) -> EquivalenceProperties {
diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
index 914e3e71d..d1ea0af69 100644
--- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
@@ -19,6 +19,7 @@
use crate::error::Result;
use crate::execution::context::TaskContext;
+use crate::physical_plan::common::transpose;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
@@ -28,19 +29,23 @@ use crate::physical_plan::{
ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
SendableRecordBatchStream, Statistics, WindowExpr,
};
-use arrow::compute::concat_batches;
+use arrow::compute::{
+ concat, concat_batches, lexicographical_partition_ranges, SortColumn,
+};
use arrow::{
array::ArrayRef,
datatypes::{Schema, SchemaRef},
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
};
+use datafusion_common::DataFusionError;
use datafusion_physical_expr::rewrite::TreeNodeRewritable;
use datafusion_physical_expr::EquivalentClass;
use futures::stream::Stream;
use futures::{ready, StreamExt};
use log::debug;
use std::any::Any;
+use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -131,6 +136,28 @@ impl WindowAggExec {
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
+
+    /// Returns the sort expressions for this operator's `PARTITION BY` keys.
+    /// For example, `OVER(PARTITION BY a ORDER BY b)` yields the sort
+    /// expression on column `a`.
+    ///
+    /// The `PARTITION BY` columns are expected at the beginning of `sort_keys`,
+    /// so the returned expressions can be used to compute partition boundaries.
+    /// Returns an empty vector when there are no window expressions.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`DataFusionError::Execution`] if a `PARTITION BY` expression has
+    /// no matching entry in `sort_keys`.
+    pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
+        // All window exprs share the same PARTITION BY clause, so the first one
+        // suffices. Using `first()` avoids a panic on an empty expression list.
+        let partition_by = match self.window_expr().first() {
+            Some(window_expr) => window_expr.partition_by(),
+            None => return Ok(vec![]),
+        };
+        let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]);
+        partition_by
+            .iter()
+            .map(|item| {
+                sort_keys
+                    .iter()
+                    .find(|e| e.expr.eq(item))
+                    .cloned()
+                    .ok_or_else(|| {
+                        DataFusionError::Execution(
+                            "Partition key not found in sort keys".to_string(),
+                        )
+                    })
+            })
+            .collect()
+    }
}
impl ExecutionPlan for WindowAggExec {
@@ -253,6 +280,7 @@ impl ExecutionPlan for WindowAggExec {
self.window_expr.clone(),
input,
BaselineMetrics::new(&self.metrics, partition),
+ self.partition_by_sort_keys()?,
));
Ok(stream)
}
@@ -337,6 +365,7 @@ pub struct WindowAggStream {
batches: Vec<RecordBatch>,
finished: bool,
window_expr: Vec<Arc<dyn WindowExpr>>,
+ partition_by_sort_keys: Vec<PhysicalSortExpr>,
baseline_metrics: BaselineMetrics,
}
@@ -347,6 +376,7 @@ impl WindowAggStream {
window_expr: Vec<Arc<dyn WindowExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
+ partition_by_sort_keys: Vec<PhysicalSortExpr>,
) -> Self {
Self {
schema,
@@ -355,6 +385,7 @@ impl WindowAggStream {
finished: false,
window_expr,
baseline_metrics,
+ partition_by_sort_keys,
}
}
@@ -368,9 +399,32 @@ impl WindowAggStream {
let batch = concat_batches(&self.input.schema(), &self.batches)?;
- // calculate window cols
- let mut columns = compute_window_aggregates(&self.window_expr, &batch)
- .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+ let partition_by_sort_keys = self
+ .partition_by_sort_keys
+ .iter()
+ .map(|elem| elem.evaluate_to_sort_column(&batch))
+ .collect::<Result<Vec<_>>>()?;
+ let partition_points =
+ self.evaluate_partition_points(batch.num_rows(), &partition_by_sort_keys)?;
+
+ let mut partition_results = vec![];
+ // Calculate window cols
+ for partition_point in partition_points {
+ let length = partition_point.end - partition_point.start;
+ partition_results.push(
+ compute_window_aggregates(
+ &self.window_expr,
+ &batch.slice(partition_point.start, length),
+ )
+ .map_err(|e| ArrowError::ExternalError(Box::new(e)))?,
+ )
+ }
+ let mut columns = transpose(partition_results)
+ .iter()
+ .map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::<Vec<_>>()))
+ .collect::<Vec<_>>()
+ .into_iter()
+ .collect::<ArrowResult<Vec<ArrayRef>>>()?;
// combine with the original cols
// note the setup of window aggregates is that they newly calculated window
@@ -378,6 +432,25 @@ impl WindowAggStream {
columns.extend_from_slice(batch.columns());
RecordBatch::try_new(self.schema.clone(), columns)
}
+
+    /// Evaluates the partition points given the sort columns of the
+    /// `PARTITION BY` keys.
+    ///
+    /// Returns the ranges of consecutive rows that share equal values in
+    /// `partition_columns`. If there are no partition columns, or the batch has
+    /// no rows, the result is a single range spanning the whole batch
+    /// (`0..num_rows`). The window expressions are therefore still evaluated
+    /// once and produce (possibly empty) output columns.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`DataFusionError::ArrowError`] if the partition ranges cannot be
+    /// computed.
+    fn evaluate_partition_points(
+        &self,
+        num_rows: usize,
+        partition_columns: &[SortColumn],
+    ) -> Result<Vec<Range<usize>>> {
+        // `lexicographical_partition_ranges` yields no ranges for zero rows. The
+        // caller would then compute no window columns at all and build an output
+        // batch that lacks the window columns its schema expects.
+        if partition_columns.is_empty() || num_rows == 0 {
+            return Ok(vec![0..num_rows]);
+        }
+        let ranges: Vec<Range<usize>> = lexicographical_partition_ranges(partition_columns)
+            .map_err(DataFusionError::ArrowError)?
+            .collect();
+        Ok(ranges)
+    }
}
impl Stream for WindowAggStream {
diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs
index 62aeb0255..90fd91164 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -61,11 +61,6 @@ async fn explain_analyze_baseline_metrics() {
"AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]",
"metrics=[output_rows=5, elapsed_compute="
);
- assert_metrics!(
- &formatted,
- "SortExec: [c1@0 ASC NULLS LAST]",
- "metrics=[output_rows=5, elapsed_compute="
- );
assert_metrics!(
&formatted,
"FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434",
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index c9ef64212..41278e120 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -1748,17 +1748,20 @@ async fn test_window_partition_by_order_by() -> Result<()> {
let msg = format!("Creating logical plan for '{}'", sql);
let dataframe = ctx.sql(sql).await.expect(&msg);
- let physical_plan = dataframe.create_physical_plan().await.unwrap();
+ let physical_plan = dataframe.create_physical_plan().await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
- // Only 1 SortExec was added
let expected = {
vec![
- "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]",
- " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: [...]
- " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
+ "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]",
+ " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
+ " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]",
" CoalesceBatchesExec: target_batch_size=4096",
- " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)",
- " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]",
+ " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
]
};
@@ -1772,3 +1775,568 @@ async fn test_window_partition_by_order_by() -> Result<()> {
);
Ok(())
}
+
+/// Two `SUM` windows with opposite `ORDER BY` directions on `c9` need only one
+/// `SortExec` (`c9 DESC`). The ASC window is evaluated on the same input by
+/// reversing its frame, which appears as `Preceding(5)..Following(1)`.
+#[tokio::test]
+async fn test_window_agg_sort_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1,
+ SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only 1 SortExec was added; the other window runs with a reversed frame.
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ // Compare every line except the last one (the scan of the CSV source).
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------------+-------------+-------------+",
+ "| c9 | sum1 | sum2 |",
+ "+------------+-------------+-------------+",
+ "| 4268716378 | 8498370520 | 24997484146 |",
+ "| 4229654142 | 12714811027 | 29012926487 |",
+ "| 4216440507 | 16858984380 | 28743001064 |",
+ "| 4144173353 | 20935849039 | 28472563256 |",
+ "| 4076864659 | 24997484146 | 28118515915 |",
+ "+------------+-------------+-------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+/// Like `test_window_agg_sort_reversed_plan`, but with built-in window functions
+/// (`FIRST_VALUE`, `LAG`, `LEAD`) in both directions. A single `SortExec`
+/// (`c9 DESC`) is expected, and the plan output and results are pinned.
+#[tokio::test]
+async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ FIRST_VALUE(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv1,
+ FIRST_VALUE(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv2,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lag1,
+ LAG(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lead1,
+ LEAD(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only 1 SortExec was added
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@6 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lag1, LAG(aggregate_tes [...]
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, di [...]
+ " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, [...]
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ // Compare every line except the last one (the scan of the CSV source).
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+------------+------------+------------+------------+------------+------------+------------+",
+ "| c9 | fv1 | fv2 | lag1 | lag2 | lead1 | lead2 |",
+ "+------------+------------+------------+------------+------------+------------+------------+",
+ "| 4268716378 | 4229654142 | 4268716378 | 4216440507 | 10101 | 10101 | 4216440507 |",
+ "| 4229654142 | 4216440507 | 4268716378 | 4144173353 | 10101 | 10101 | 4144173353 |",
+ "| 4216440507 | 4144173353 | 4229654142 | 4076864659 | 4268716378 | 4268716378 | 4076864659 |",
+ "| 4144173353 | 4076864659 | 4216440507 | 4061635107 | 4229654142 | 4229654142 | 4061635107 |",
+ "| 4076864659 | 4061635107 | 4144173353 | 4015442341 | 4216440507 | 4216440507 | 4015442341 |",
+ "+------------+------------+------------+------------+------------+------------+------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+/// `ROW_NUMBER` windows with opposite `ORDER BY` directions cannot share one
+/// sort, because `ROW_NUMBER` is not reversible. Both `SortExec`s
+/// (`c9 DESC` and `c9 ASC NULLS LAST`) must remain in the plan.
+#[tokio::test]
+async fn test_window_agg_sort_non_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ ROW_NUMBER() OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn1,
+ ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // We cannot reverse each window function (ROW_NUMBER is not reversible)
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@2 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@1 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@0 DESC]",
+ ]
+ };
+
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ // Compare every line except the last one (the scan of the CSV source).
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+-----+-----+",
+ "| c9 | rn1 | rn2 |",
+ "+-----------+-----+-----+",
+ "| 28774375 | 1 | 100 |",
+ "| 63044568 | 2 | 99 |",
+ "| 141047417 | 3 | 98 |",
+ "| 141680161 | 4 | 97 |",
+ "| 145294611 | 5 | 96 |",
+ "+-----------+-----+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+/// Plan + result test for windows that cannot all share one ordering: the expected
+/// plan keeps two SortExecs (`c9 DESC, c1 DESC` at the bottom, and
+/// `c9 ASC, c1 ASC, c2 ASC` above the ROW_NUMBER / SUM windows).
+#[tokio::test]
+async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c9 ASC, c1 ASC, c2 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1,
+ SUM(c9) OVER(ORDER BY c9 DESC, c1 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2,
+ ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Not every window function can be reversed (ROW_NUMBER is not reversible),
+ // so the expected plan keeps two SortExecs instead of reusing one ordering.
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWE [...]
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c9@2 DESC,c1@0 DESC]",
+ ]
+ };
+
+ // Compare every rendered plan line except the last one.
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ // The plan must also produce the same first five rows (LIMIT 5).
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------+------------+-----------+-----+",
+ "| c9 | sum1 | sum2 | rn2 |",
+ "+-----------+------------+-----------+-----+",
+ "| 28774375 | 745354217 | 91818943 | 100 |",
+ "| 63044568 | 988558066 | 232866360 | 99 |",
+ "| 141047417 | 1285934966 | 374546521 | 98 |",
+ "| 141680161 | 1654839259 | 519841132 | 97 |",
+ "| 145294611 | 1980231675 | 745354217 | 96 |",
+ "+-----------+------------+-----------+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+/// Plan-only test: many SUM windows over `c3` that mix ASC/DESC, NULLS FIRST/LAST
+/// and several frame shapes. The expected optimized plan contains only three
+/// SortExecs, and some windows run with reversed frames
+/// (e.g. `Preceding(Int64(11))` / `Following(Int64(10))`). Query results are not checked.
+#[tokio::test]
+async fn test_window_agg_complex_plan() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_null_cases_csv(&ctx).await?;
+ let sql = "SELECT
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as a,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as b,
+ SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as c,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as d,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as e,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as f,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as g,
+ SUM(c1) OVER (ORDER BY c3) as h,
+ SUM(c1) OVER (ORDER BY c3 DESC) as i,
+ SUM(c1) OVER (ORDER BY c3 NULLS first) as j,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS first) as k,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS last) as l,
+ SUM(c1) OVER (ORDER BY c3, c2) as m,
+ SUM(c1) OVER (ORDER BY c3, c1 DESC) as n,
+ SUM(c1) OVER (ORDER BY c3 DESC, c1) as o,
+ SUM(c1) OVER (ORDER BY c3, c1 NULLs first) as p,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a1,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as b1,
+ SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as c1,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as d1,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as e1,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as f1,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as g1,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as h1,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as j1,
+ SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as k1,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as l1,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as m1,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as n1,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as o1,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as h11,
+ SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as j11,
+ SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as k11,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as l11,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as m11,
+ SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as n11,
+ SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as o11
+ FROM null_cases
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Unnecessary SortExecs are removed: only three remain in the expected plan.
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as d, SUM(null_ [...]
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preced [...]
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " SortExec: [c3@17 ASC NULLS LAST,c2@16 ASC NULLS LAST]",
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " SortExec: [c3@16 ASC NULLS LAST,c1@14 ASC]",
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]",
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_b [...]
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start [...]
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, sta [...]
+ " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]",
+ " SortExec: [c3@2 DESC,c1@0 ASC NULLS LAST]",
+ ]
+ };
+
+ // Compare every rendered plan line except the last one.
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ Ok(())
+}
+
+/// With window repartitioning disabled, `SUM ... ORDER BY c1, c9 DESC` and
+/// `SUM ... PARTITION BY c1 ORDER BY c9 DESC` are both expected to be served by a
+/// single `SortExec: [c1 ASC NULLS LAST, c9 DESC]`. The first five result rows are checked too.
+#[tokio::test]
+async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> {
+ let ctx = SessionContext::with_config(
+ SessionConfig::new().with_repartition_windows(false),
+ );
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(ORDER BY c1, c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1,
+ SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ // Build the optimized physical plan and render it one operator per line.
+ let context_msg = format!("Creating logical plan for '{}'", sql);
+ let df = ctx.sql(sql).await.expect(&context_msg);
+ let plan = df.create_physical_plan().await?;
+ let rendered = displayable(plan.as_ref()).indent().to_string();
+ let plan_lines: Vec<&str> = rendered.trim().lines().collect();
+
+ // A single SortExec sits under both window operators.
+ let expected_plan = vec![
+ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]",
+ ];
+ // The final plan line (presumably the data-source scan) is left out of the comparison.
+ let compared_lines = &plan_lines[..plan_lines.len() - 1];
+ assert_eq!(
+ expected_plan, compared_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_plan, plan_lines
+ );
+
+ let batches = execute_to_batches(&ctx, sql).await;
+ let expected_rows = vec![
+ "+------------+-------------+-------------+",
+ "| c9 | sum1 | sum2 |",
+ "+------------+-------------+-------------+",
+ "| 4015442341 | 21907044499 | 21907044499 |",
+ "| 3998790955 | 24576419362 | 24576419362 |",
+ "| 3959216334 | 23063303501 | 23063303501 |",
+ "| 3717551163 | 21560567246 | 21560567246 |",
+ "| 3276123488 | 19815386638 | 19815386638 |",
+ "+------------+-------------+-------------+",
+ ];
+ assert_batches_eq!(expected_rows, &batches);
+
+ Ok(())
+}
+
+/// Two windows partitioned by `c1` with opposite `c9` orderings. The expected plan sorts
+/// once by `c1, c9 DESC` and evaluates the `c9 ASC` window with the mirrored frame
+/// (`Preceding(5)` / `Following(1)` instead of `1 PRECEDING AND 5 FOLLOWING`).
+#[tokio::test]
+async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::with_config(
+ SessionConfig::new().with_repartition_windows(false),
+ );
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT
+ c9,
+ SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1,
+ SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ // Build the optimized physical plan and render it one operator per line.
+ let context_msg = format!("Creating logical plan for '{}'", sql);
+ let df = ctx.sql(sql).await.expect(&context_msg);
+ let plan = df.create_physical_plan().await?;
+ let rendered = displayable(plan.as_ref()).indent().to_string();
+ let plan_lines: Vec<&str> = rendered.trim().lines().collect();
+
+ // A single SortExec; the upper window runs with a reversed frame.
+ let expected_plan = vec![
+ "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]",
+ " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]",
+ ];
+ // The final plan line (presumably the data-source scan) is left out of the comparison.
+ let compared_lines = &plan_lines[..plan_lines.len() - 1];
+ assert_eq!(
+ expected_plan, compared_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_plan, plan_lines
+ );
+
+ let batches = execute_to_batches(&ctx, sql).await;
+ let expected_rows = vec![
+ "+------------+-------------+-------------+",
+ "| c9 | sum1 | sum2 |",
+ "+------------+-------------+-------------+",
+ "| 4015442341 | 8014233296 | 21907044499 |",
+ "| 3998790955 | 11973449630 | 24576419362 |",
+ "| 3959216334 | 15691000793 | 23063303501 |",
+ "| 3717551163 | 18967124281 | 21560567246 |",
+ "| 3276123488 | 21907044499 | 19815386638 |",
+ "+------------+-------------+-------------+",
+ ];
+ assert_batches_eq!(expected_rows, &batches);
+
+ Ok(())
+}
+
+/// Orderings over a binary expression (`c3+c4`): the `c3+c4 ASC, c9 ASC` window is
+/// expected to reuse the `c3+c4 DESC, c9 DESC, c2 ASC` sort, so the plan has one SortExec.
+#[tokio::test]
+async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> {
+ let config = SessionConfig::new().with_repartition_windows(false);
+ let ctx = SessionContext::with_config(config);
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT c3,
+ SUM(c9) OVER(ORDER BY c3+c4 DESC, c9 DESC, c2 ASC) as sum1,
+ SUM(c9) OVER(ORDER BY c3+c4 ASC, c9 ASC ) as sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let dataframe = ctx.sql(sql).await.expect(&msg);
+ let physical_plan = dataframe.create_physical_plan().await?;
+ let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+ // Only one SortExec appears in the expected plan.
+ let expected = {
+ vec![
+ "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ord [...]
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED P [...]
+ " SortExec: [CAST(c3@1 AS Int16) + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST]",
+ ]
+ };
+
+ // Compare every rendered plan line except the last one.
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ let actual_len = actual.len();
+ let actual_trim_last = &actual[..actual_len - 1];
+ assert_eq!(
+ expected, actual_trim_last,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ // The plan must also produce the same first five rows (LIMIT 5).
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+-------------+--------------+",
+ "| c3 | sum1 | sum2 |",
+ "+-----+-------------+--------------+",
+ "| -86 | 2861911482 | 222089770060 |",
+ "| 13 | 5075947208 | 219227858578 |",
+ "| 125 | 8701233618 | 217013822852 |",
+ "| 123 | 11293564174 | 213388536442 |",
+ "| 97 | 14767488750 | 210796205886 |",
+ "+-----+-------------+--------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
+
+/// The `ORDER BY c1` inside the subquery cannot affect the outer `count(*)`, so the
+/// optimized plan is expected to contain no SortExec at all (8 target partitions).
+#[tokio::test]
+async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> {
+ let ctx = SessionContext::with_config(
+ SessionConfig::new()
+ .with_target_partitions(8)
+ .with_repartition_windows(true),
+ );
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT count(*) as global_count FROM
+ (SELECT count(*), c1
+ FROM aggregate_test_100
+ WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434'
+ GROUP BY c1
+ ORDER BY c1 ) AS a ";
+
+ // Build the optimized physical plan and render it one operator per line.
+ let context_msg = format!("Creating logical plan for '{}'", sql);
+ let df = ctx.sql(sql).await.expect(&context_msg);
+ let plan = df.create_physical_plan().await?;
+ let rendered = displayable(plan.as_ref()).indent().to_string();
+ let plan_lines: Vec<&str> = rendered.trim().lines().collect();
+
+ // No SortExec: the subquery's sort has been optimized away.
+ let expected_plan = vec![
+ "ProjectionExec: expr=[COUNT(UInt8(1))@0 as global_count]",
+ " AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]",
+ " CoalescePartitionsExec",
+ " AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]",
+ " RepartitionExec: partitioning=RoundRobinBatch(8)",
+ " CoalescePartitionsExec",
+ " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8)",
+ " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434",
+ " RepartitionExec: partitioning=RoundRobinBatch(8)",
+ ];
+ // The final plan line (presumably the data-source scan) is left out of the comparison.
+ let compared_lines = &plan_lines[..plan_lines.len() - 1];
+ assert_eq!(
+ expected_plan, compared_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_plan, plan_lines
+ );
+
+ let batches = execute_to_batches(&ctx, sql).await;
+ let expected_rows = vec![
+ "+--------------+",
+ "| global_count |",
+ "+--------------+",
+ "| 5 |",
+ "+--------------+",
+ ];
+ assert_batches_eq!(expected_rows, &batches);
+
+ Ok(())
+}
+
+/// `PARTITION BY c3 ORDER BY c9 DESC` is expected to reuse the
+/// `c3 DESC, c9 DESC, c2 ASC` sort of the other window, so the plan has one SortExec.
+#[tokio::test]
+async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Result<()> {
+ let ctx = SessionContext::with_config(
+ SessionConfig::new().with_repartition_windows(false),
+ );
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT c3,
+ SUM(c9) OVER(ORDER BY c3 DESC, c9 DESC, c2 ASC) as sum1,
+ SUM(c9) OVER(PARTITION BY c3 ORDER BY c9 DESC ) as sum2
+ FROM aggregate_test_100
+ LIMIT 5";
+
+ // Build the optimized physical plan and render it one operator per line.
+ let context_msg = format!("Creating logical plan for '{}'", sql);
+ let df = ctx.sql(sql).await.expect(&context_msg);
+ let plan = df.create_physical_plan().await?;
+ let rendered = displayable(plan.as_ref()).indent().to_string();
+ let plan_lines: Vec<&str> = rendered.trim().lines().collect();
+
+ // A single SortExec sits under both window operators.
+ let expected_plan = vec![
+ "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]",
+ " GlobalLimitExec: skip=0, fetch=5",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
+ " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]",
+ " SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]",
+ ];
+ // The final plan line (presumably the data-source scan) is left out of the comparison.
+ let compared_lines = &plan_lines[..plan_lines.len() - 1];
+ assert_eq!(
+ expected_plan, compared_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_plan, plan_lines
+ );
+
+ let batches = execute_to_batches(&ctx, sql).await;
+ let expected_rows = vec![
+ "+-----+-------------+------------+",
+ "| c3 | sum1 | sum2 |",
+ "+-----+-------------+------------+",
+ "| 125 | 3625286410 | 3625286410 |",
+ "| 123 | 7192027599 | 3566741189 |",
+ "| 123 | 9784358155 | 6159071745 |",
+ "| 122 | 13845993262 | 4061635107 |",
+ "| 120 | 16676974334 | 2830981072 |",
+ "+-----+-------------+------------+",
+ ];
+ assert_batches_eq!(expected_rows, &batches);
+
+ Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index bf2a1d001..eeb3215c4 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -257,7 +257,7 @@ impl LogicalPlanBuilder {
// The sort_by() implementation here is a stable sort.
// Note that by this rule if there's an empty over, it'll be at the top level
groups.sort_by(|(key_a, _), (key_b, _)| {
- for (first, second) in key_a.iter().zip(key_b.iter()) {
+ for ((first, _), (second, _)) in key_a.iter().zip(key_b.iter()) {
let key_ordering = compare_sort_expr(first, second, plan.schema());
match key_ordering {
Ordering::Less => {
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 3ee36de17..ca06dfdb4 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -204,7 +204,9 @@ pub fn expand_qualified_wildcard(
expand_wildcard(&qualifier_schema, plan)
}
-type WindowSortKey = Vec<Expr>;
+/// Sort key of a window group: each sort expression is paired with a flag that is
+/// `true` when it comes from a `PARTITION BY` column and `false` when it comes from `ORDER BY`.
+type WindowSortKey = Vec<(Expr, bool)>;
-/// Generate a sort key for a given window expr's partition_by and order_bu expr
+/// Generate a sort key for a given window expr's partition_by and order_by expr
pub fn generate_sort_key(
@@ -224,6 +226,7 @@ pub fn generate_sort_key(
.collect::<Result<Vec<_>>>()?;
let mut final_sort_keys = vec![];
+ let mut is_partition_flag = vec![];
partition_by.iter().for_each(|e| {
// By default, create sort key with ASC is true and NULLS LAST to be consistent with
// PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
@@ -232,18 +235,26 @@ pub fn generate_sort_key(
let order_by_key = &order_by[pos];
if !final_sort_keys.contains(order_by_key) {
final_sort_keys.push(order_by_key.clone());
+ is_partition_flag.push(true);
}
} else if !final_sort_keys.contains(&e) {
final_sort_keys.push(e);
+ is_partition_flag.push(true);
}
});
order_by.iter().for_each(|e| {
if !final_sort_keys.contains(e) {
final_sort_keys.push(e.clone());
+ is_partition_flag.push(false);
}
});
- Ok(final_sort_keys)
+ let res = final_sort_keys
+ .into_iter()
+ .zip(is_partition_flag)
+ // zip already yields (key, flag): true = from PARTITION BY, false = added by ORDER BY.
+ .collect::<Vec<_>>();
+ Ok(res)
}
/// Compare the sort expr as PostgreSQL's common_prefix_cmp():
@@ -1043,9 +1054,13 @@ mod tests {
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs)?;
- let key1 = vec![age_asc.clone(), name_desc.clone()];
+ let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
let key2 = vec![];
- let key3 = vec![name_desc, age_asc, created_at_desc];
+ let key3 = vec![
+ (name_desc, false),
+ (age_asc, false),
+ (created_at_desc, false),
+ ];
let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![
(key1, vec![&max1, &min3]),
@@ -1112,21 +1127,30 @@ mod tests {
];
let expected = vec![
- Expr::Sort(Sort {
- expr: Box::new(col("age")),
- asc: asc_,
- nulls_first: nulls_first_,
- }),
- Expr::Sort(Sort {
- expr: Box::new(col("name")),
- asc: asc_,
- nulls_first: nulls_first_,
- }),
- Expr::Sort(Sort {
- expr: Box::new(col("created_at")),
- asc: true,
- nulls_first: false,
- }),
+ (
+ Expr::Sort(Sort {
+ expr: Box::new(col("age")),
+ asc: asc_,
+ nulls_first: nulls_first_,
+ }),
+ true,
+ ),
+ (
+ Expr::Sort(Sort {
+ expr: Box::new(col("name")),
+ asc: asc_,
+ nulls_first: nulls_first_,
+ }),
+ true,
+ ),
+ (
+ Expr::Sort(Sort {
+ expr: Box::new(col("created_at")),
+ asc: true,
+ nulls_first: false,
+ }),
+ true,
+ ),
];
let result = generate_sort_key(partition_by, order_by)?;
assert_eq!(expected, result);
diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs
index 62c7c57d4..100ea8e1d 100644
--- a/datafusion/expr/src/window_frame.rs
+++ b/datafusion/expr/src/window_frame.rs
@@ -113,6 +113,35 @@ impl WindowFrame {
}
}
}
+
+ /// Returns the mirror image of this frame: the old end bound becomes the new
+ /// start bound (and vice versa), with PRECEDING and FOLLOWING swapped. For example
+ /// `3 ROWS PRECEDING AND 2 ROWS FOLLOWING` -->
+ /// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING`
+ pub fn reverse(&self) -> Self {
+ // Flip the direction of a single bound; CURRENT ROW maps to itself.
+ let mirror = |bound: &WindowFrameBound| match bound {
+ WindowFrameBound::Preceding(offset) => {
+ WindowFrameBound::Following(offset.clone())
+ }
+ WindowFrameBound::Following(offset) => {
+ WindowFrameBound::Preceding(offset.clone())
+ }
+ WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
+ };
+ WindowFrame {
+ units: self.units,
+ start_bound: mirror(&self.end_bound),
+ end_bound: mirror(&self.start_bound),
+ }
+ }
}
/// There are five ways to describe starting and ending frame boundaries:
diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs
index 6c43344db..813952117 100644
--- a/datafusion/physical-expr/src/aggregate/count.rs
+++ b/datafusion/physical-expr/src/aggregate/count.rs
@@ -36,7 +36,7 @@ use crate::expressions::format_state_name;
/// COUNT aggregate expression
/// Returns the amount of non-null values of the given expression.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Count {
name: String,
data_type: DataType,
@@ -105,6 +105,10 @@ impl AggregateExpr for Count {
Ok(Box::new(CountRowAccumulator::new(start_index)))
}
+ /// COUNT's reverse is COUNT itself, so this returns a copy of this expression.
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ let same_aggregate: Arc<dyn AggregateExpr> = Arc::new(self.clone());
+ Some(same_aggregate)
+ }
+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(CountAccumulator::new()))
}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index 436a23396..947336596 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -103,6 +103,14 @@ pub trait AggregateExpr: Send + Sync + Debug {
)))
}
+ /// Construct an expression that calculates the same aggregate when the input
+ /// rows are visited in reverse order.
+ ///
+ /// For aggregates whose result does not depend on row order (e.g. SUM, COUNT),
+ /// the reverse expression is the aggregate itself. Aggregates that cannot be
+ /// computed in reverse return `None`; this default implementation returns `None`,
+ /// so implementors opt in by overriding it.
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ None
+ }
+
/// Creates accumulator implementation that supports retract
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Err(DataFusionError::NotImplemented(format!(
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index 8d2620296..c2d54c40e 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -44,7 +44,7 @@ use arrow::compute::cast;
use datafusion_row::accessor::RowAccessor;
/// SUM aggregate expression
-#[derive(Debug)]
+#[derive(Debug, Clone)]
pub struct Sum {
name: String,
data_type: DataType,
@@ -123,6 +123,10 @@ impl AggregateExpr for Sum {
)))
}
+ /// SUM's reverse is SUM itself, so this returns a copy of this expression.
+ fn reverse_expr(&self) -> Option<Arc<dyn AggregateExpr>> {
+ let same_aggregate: Arc<dyn AggregateExpr> = Arc::new(self.clone());
+ Some(same_aggregate)
+ }
+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
}
diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs
index c42f7ff55..5c46f38f2 100644
--- a/datafusion/physical-expr/src/window/aggregate.rs
+++ b/datafusion/physical-expr/src/window/aggregate.rs
@@ -19,6 +19,7 @@
use std::any::Any;
use std::iter::IntoIterator;
+use std::ops::Range;
use std::sync::Arc;
use arrow::array::Array;
@@ -30,6 +31,8 @@ use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::WindowFrame;
+use crate::window::window_expr::reverse_order_bys;
+use crate::window::SlidingAggregateWindowExpr;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};
@@ -89,49 +92,41 @@ impl WindowExpr for AggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let partition_columns = self.partition_columns(batch)?;
- let partition_points =
- self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results: Vec<ScalarValue> = vec![];
- for partition_range in &partition_points {
- let mut accumulator = self.aggregate.create_accumulator()?;
- let length = partition_range.end - partition_range.start;
- let (values, order_bys) =
- self.get_values_orderbys(&batch.slice(partition_range.start, length))?;
-
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- let mut last_range: (usize, usize) = (0, 0);
-
- // We iterate on each row to perform a running calculation.
- // First, cur_range is calculated, then it is compared with last_range.
- for i in 0..length {
- let cur_range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- length,
- i,
- )?;
- let value = if cur_range.0 == cur_range.1 {
- // We produce None if the window is empty.
- ScalarValue::try_from(self.aggregate.field()?.data_type())?
- } else {
- // Accumulate any new rows that have entered the window:
- let update_bound = cur_range.1 - last_range.1;
- if update_bound > 0 {
- let update: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.1, update_bound))
- .collect();
- accumulator.update_batch(&update)?
- }
- accumulator.evaluate()?
- };
- row_wise_results.push(value);
- last_range = cur_range;
- }
+
+ let mut accumulator = self.aggregate.create_accumulator()?;
+ let length = batch.num_rows();
+ let (values, order_bys) = self.get_values_orderbys(batch)?;
+
+ let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
+ let mut last_range = Range { start: 0, end: 0 };
+
+ // We iterate on each row to perform a running calculation.
+ // First, cur_range is calculated, then it is compared with last_range.
+ for i in 0..length {
+ let cur_range =
+ window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?;
+ let value = if cur_range.end == cur_range.start {
+ // We produce None if the window is empty.
+ ScalarValue::try_from(self.aggregate.field()?.data_type())?
+ } else {
+ // Accumulate any new rows that have entered the window:
+ let update_bound = cur_range.end - last_range.end;
+ if update_bound > 0 {
+ let update: Vec<ArrayRef> = values
+ .iter()
+ .map(|v| v.slice(last_range.end, update_bound))
+ .collect();
+ accumulator.update_batch(&update)?
+ }
+ accumulator.evaluate()?
+ };
+ row_wise_results.push(value);
+ last_range = cur_range;
}
+
ScalarValue::iter_to_array(row_wise_results.into_iter())
}
@@ -146,4 +141,25 @@ impl WindowExpr for AggregateWindowExpr {
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
+
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
+ self.aggregate.reverse_expr().map(|reverse_expr| {
+ let reverse_window_frame = self.window_frame.reverse();
+ if reverse_window_frame.start_bound.is_unbounded() {
+ Arc::new(AggregateWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )) as _
+ } else {
+ Arc::new(SlidingAggregateWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )) as _
+ }
+ })
+ }
}
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index 95bf01608..9804432b2 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -20,15 +20,15 @@
use super::window_frame_state::WindowFrameContext;
use super::BuiltInWindowFunctionExpr;
use super::WindowExpr;
+use crate::window::window_expr::reverse_order_bys;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
-use arrow::compute::{concat, SortOptions};
+use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::DataFusionError;
use datafusion_common::Result;
+use datafusion_common::ScalarValue;
use datafusion_expr::WindowFrame;
use std::any::Any;
-use std::ops::Range;
use std::sync::Arc;
/// A window expr that takes the form of a built in window function
@@ -91,50 +91,49 @@ impl WindowExpr for BuiltInWindowExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let evaluator = self.expr.create_evaluator()?;
let num_rows = batch.num_rows();
- let partition_columns = self.partition_columns(batch)?;
- let partition_points =
- self.evaluate_partition_points(num_rows, &partition_columns)?;
-
- let results = if evaluator.uses_window_frame() {
+ if evaluator.uses_window_frame() {
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results = vec![];
- for partition_range in &partition_points {
- let length = partition_range.end - partition_range.start;
- let (values, order_bys) = self
- .get_values_orderbys(&batch.slice(partition_range.start, length))?;
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- // We iterate on each row to calculate window frame range and and window function result
- for idx in 0..length {
- let range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- num_rows,
- idx,
- )?;
- let range = Range {
- start: range.0,
- end: range.1,
- };
- let value = evaluator.evaluate_inside_range(&values, range)?;
- row_wise_results.push(value.to_array());
- }
+
+ let length = batch.num_rows();
+ let (values, order_bys) = self.get_values_orderbys(batch)?;
+ let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
+        // We iterate on each row to calculate its window frame range and window function result
+ for idx in 0..length {
+ let range = window_frame_ctx.calculate_range(
+ &order_bys,
+ &sort_options,
+ num_rows,
+ idx,
+ )?;
+ let value = evaluator.evaluate_inside_range(&values, range)?;
+ row_wise_results.push(value);
}
- row_wise_results
+ ScalarValue::iter_to_array(row_wise_results.into_iter())
} else if evaluator.include_rank() {
let columns = self.sort_columns(batch)?;
let sort_partition_points =
self.evaluate_partition_points(num_rows, &columns)?;
- evaluator.evaluate_with_rank(partition_points, sort_partition_points)?
+ evaluator.evaluate_with_rank(num_rows, &sort_partition_points)
} else {
let (values, _) = self.get_values_orderbys(batch)?;
- evaluator.evaluate(&values, partition_points)?
- };
- let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
- concat(&results).map_err(DataFusionError::ArrowError)
+ evaluator.evaluate(&values, num_rows)
+ }
}
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
+
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
+ self.expr.reverse_expr().map(|reverse_expr| {
+ Arc::new(BuiltInWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )) as _
+ })
+ }
}
diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
index 7f7a27435..c358403fe 100644
--- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
+++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs
@@ -58,4 +58,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
/// Create built-in window evaluator with a batch
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
+
+    /// Constructs the reverse expression, which produces the same result on a
+    /// reversed window (e.g. `lead(10)` --> `lag(10)`); returns `None` by default.
+ fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
+ None
+ }
}
diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs
index 4202058a3..45fe51178 100644
--- a/datafusion/physical-expr/src/window/cume_dist.rs
+++ b/datafusion/physical-expr/src/window/cume_dist.rs
@@ -73,19 +73,19 @@ impl PartitionEvaluator for CumeDistEvaluator {
true
}
- fn evaluate_partition_with_rank(
+ fn evaluate_with_rank(
&self,
- partition: Range<usize>,
+ num_rows: usize,
ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
- let scaler = (partition.end - partition.start) as f64;
+ let scalar = num_rows as f64;
let result = Float64Array::from_iter_values(
ranks_in_partition
.iter()
.scan(0_u64, |acc, range| {
let len = range.end - range.start;
*acc += len as u64;
- let value: f64 = (*acc as f64) / scaler;
+ let value: f64 = (*acc as f64) / scalar;
let result = iter::repeat(value).take(len);
Some(result)
})
@@ -102,15 +102,14 @@ mod tests {
fn test_i32_result(
expr: &CumeDist,
- partition: Range<usize>,
+ num_rows: usize,
ranks: Vec<Range<usize>>,
expected: Vec<f64>,
) -> Result<()> {
let result = expr
.create_evaluator()?
- .evaluate_with_rank(vec![partition], ranks)?;
- assert_eq!(1, result.len());
- let result = as_float64_array(&result[0])?;
+ .evaluate_with_rank(num_rows, &ranks)?;
+ let result = as_float64_array(&result)?;
let result = result.values();
assert_eq!(expected, result);
Ok(())
@@ -121,19 +120,19 @@ mod tests {
let r = cume_dist("arr".into());
let expected = vec![0.0; 0];
- test_i32_result(&r, 0..0, vec![], expected)?;
+ test_i32_result(&r, 0, vec![], expected)?;
let expected = vec![1.0; 1];
- test_i32_result(&r, 0..1, vec![0..1], expected)?;
+ test_i32_result(&r, 1, vec![0..1], expected)?;
let expected = vec![1.0; 2];
- test_i32_result(&r, 0..2, vec![0..2], expected)?;
+ test_i32_result(&r, 2, vec![0..2], expected)?;
let expected = vec![0.5, 0.5, 1.0, 1.0];
- test_i32_result(&r, 0..4, vec![0..2, 2..4], expected)?;
+ test_i32_result(&r, 4, vec![0..2, 2..4], expected)?;
let expected = vec![0.25, 0.5, 0.75, 1.0];
- test_i32_result(&r, 0..4, vec![0..1, 1..2, 2..3, 3..4], expected)?;
+ test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?;
Ok(())
}
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index c7fc73b9f..e18815c4c 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -28,7 +28,6 @@ use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
use std::ops::Neg;
-use std::ops::Range;
use std::sync::Arc;
/// window shift expression
@@ -107,6 +106,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
default_value: self.default_value.clone(),
}))
}
+
+ fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
+ Some(Arc::new(Self {
+ name: self.name.clone(),
+ data_type: self.data_type.clone(),
+ shift_offset: -self.shift_offset,
+ expr: self.expr.clone(),
+ default_value: self.default_value.clone(),
+ }))
+ }
}
pub(crate) struct WindowShiftEvaluator {
@@ -164,15 +173,10 @@ fn shift_with_default_value(
}
impl PartitionEvaluator for WindowShiftEvaluator {
- fn evaluate_partition(
- &self,
- values: &[ArrayRef],
- partition: Range<usize>,
- ) -> Result<ArrayRef> {
+ fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result<ArrayRef> {
// LEAD, LAG window functions take single column, values will have size 1
let value = &values[0];
- let value = value.slice(partition.start, partition.end - partition.start);
- shift_with_default_value(&value, self.shift_offset, self.default_value.as_ref())
+ shift_with_default_value(value, self.shift_offset, self.default_value.as_ref())
}
}
@@ -191,9 +195,10 @@ mod tests {
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
let values = expr.evaluate_args(&batch)?;
- let result = expr.create_evaluator()?.evaluate(&values, vec![0..8])?;
- assert_eq!(1, result.len());
- let result = as_int32_array(&result[0])?;
+ let result = expr
+ .create_evaluator()?
+ .evaluate(&values, batch.num_rows())?;
+ let result = as_int32_array(&result)?;
assert_eq!(expected, *result);
Ok(())
}
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index 63a2354c9..e998b4701 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -123,6 +123,20 @@ impl BuiltInWindowFunctionExpr for NthValue {
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(NthValueEvaluator { kind: self.kind }))
}
+
+ fn reverse_expr(&self) -> Option<Arc<dyn BuiltInWindowFunctionExpr>> {
+ let reversed_kind = match self.kind {
+ NthValueKind::First => NthValueKind::Last,
+ NthValueKind::Last => NthValueKind::First,
+ NthValueKind::Nth(_) => return None,
+ };
+ Some(Arc::new(Self {
+ name: self.name.clone(),
+ expr: self.expr.clone(),
+ data_type: self.data_type.clone(),
+ kind: reversed_kind,
+ }))
+ }
}
/// Value evaluator for nth_value functions
diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs
index ed00c3c86..f5844eccc 100644
--- a/datafusion/physical-expr/src/window/ntile.rs
+++ b/datafusion/physical-expr/src/window/ntile.rs
@@ -26,7 +26,6 @@ use arrow::datatypes::Field;
use arrow_schema::DataType;
use datafusion_common::Result;
use std::any::Any;
-use std::ops::Range;
use std::sync::Arc;
#[derive(Debug)]
@@ -70,12 +69,8 @@ pub(crate) struct NtileEvaluator {
}
impl PartitionEvaluator for NtileEvaluator {
- fn evaluate_partition(
- &self,
- _values: &[ArrayRef],
- partition: Range<usize>,
- ) -> Result<ArrayRef> {
- let num_rows = (partition.end - partition.start) as u64;
+ fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
+ let num_rows = num_rows as u64;
let mut vec: Vec<u64> = Vec::new();
for i in 0..num_rows {
let res = i * self.n / num_rows;
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs
index 1608758d6..86500441d 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -22,23 +22,6 @@ use datafusion_common::Result;
use datafusion_common::{DataFusionError, ScalarValue};
use std::ops::Range;
-/// Given a partition range, and the full list of sort partition points, given that the sort
-/// partition points are sorted using [partition columns..., order columns...], the split
-/// boundaries would align (what's sorted on [partition columns...] would definitely be sorted
-/// on finer columns), so this will use binary search to find ranges that are within the
-/// partition range and return the valid slice.
-pub(crate) fn find_ranges_in_range<'a>(
- partition_range: &Range<usize>,
- sort_partition_points: &'a [Range<usize>],
-) -> &'a [Range<usize>] {
- let start_idx = sort_partition_points
- .partition_point(|sort_range| sort_range.start < partition_range.start);
- let end_idx = start_idx
- + sort_partition_points[start_idx..]
- .partition_point(|sort_range| sort_range.end <= partition_range.end);
- &sort_partition_points[start_idx..end_idx]
-}
-
/// Partition evaluator
pub trait PartitionEvaluator {
/// Whether the evaluator should be evaluated with rank
@@ -50,49 +33,17 @@ pub trait PartitionEvaluator {
false
}
- /// evaluate the partition evaluator against the partitions
- fn evaluate(
- &self,
- values: &[ArrayRef],
- partition_points: Vec<Range<usize>>,
- ) -> Result<Vec<ArrayRef>> {
- partition_points
- .into_iter()
- .map(|partition| self.evaluate_partition(values, partition))
- .collect()
- }
-
- /// evaluate the partition evaluator against the partitions with rank information
- fn evaluate_with_rank(
- &self,
- partition_points: Vec<Range<usize>>,
- sort_partition_points: Vec<Range<usize>>,
- ) -> Result<Vec<ArrayRef>> {
- partition_points
- .into_iter()
- .map(|partition| {
- let ranks_in_partition =
- find_ranges_in_range(&partition, &sort_partition_points);
- self.evaluate_partition_with_rank(partition, ranks_in_partition)
- })
- .collect()
- }
-
/// evaluate the partition evaluator against the partition
- fn evaluate_partition(
- &self,
- _values: &[ArrayRef],
- _partition: Range<usize>,
- ) -> Result<ArrayRef> {
+ fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result<ArrayRef> {
Err(DataFusionError::NotImplemented(
"evaluate_partition is not implemented by default".into(),
))
}
/// evaluate the partition evaluator against the partition but with rank
- fn evaluate_partition_with_rank(
+ fn evaluate_with_rank(
&self,
- _partition: Range<usize>,
+ _num_rows: usize,
_ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
Err(DataFusionError::NotImplemented(
diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs
index 8ed0319a1..87e01528d 100644
--- a/datafusion/physical-expr/src/window/rank.rs
+++ b/datafusion/physical-expr/src/window/rank.rs
@@ -114,9 +114,9 @@ impl PartitionEvaluator for RankEvaluator {
true
}
- fn evaluate_partition_with_rank(
+ fn evaluate_with_rank(
&self,
- partition: Range<usize>,
+ num_rows: usize,
ranks_in_partition: &[Range<usize>],
) -> Result<ArrayRef> {
// see https://www.postgresql.org/docs/current/functions-window.html
@@ -132,7 +132,7 @@ impl PartitionEvaluator for RankEvaluator {
)),
RankType::Percent => {
// Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive.
- let denominator = (partition.end - partition.start) as f64;
+ let denominator = num_rows as f64;
Arc::new(Float64Array::from_iter_values(
ranks_in_partition
.iter()
@@ -177,15 +177,14 @@ mod tests {
fn test_f64_result(
expr: &Rank,
- range: Range<usize>,
+ num_rows: usize,
ranks: Vec<Range<usize>>,
expected: Vec<f64>,
) -> Result<()> {
let result = expr
.create_evaluator()?
- .evaluate_with_rank(vec![range], ranks)?;
- assert_eq!(1, result.len());
- let result = as_float64_array(&result[0])?;
+ .evaluate_with_rank(num_rows, &ranks)?;
+ let result = as_float64_array(&result)?;
let result = result.values();
assert_eq!(expected, result);
Ok(())
@@ -196,11 +195,8 @@ mod tests {
ranks: Vec<Range<usize>>,
expected: Vec<u64>,
) -> Result<()> {
- let result = expr
- .create_evaluator()?
- .evaluate_with_rank(vec![0..8], ranks)?;
- assert_eq!(1, result.len());
- let result = as_uint64_array(&result[0])?;
+ let result = expr.create_evaluator()?.evaluate_with_rank(8, &ranks)?;
+ let result = as_uint64_array(&result)?;
let result = result.values();
assert_eq!(expected, result);
Ok(())
@@ -228,19 +224,19 @@ mod tests {
// empty case
let expected = vec![0.0; 0];
- test_f64_result(&r, 0..0, vec![0..0; 0], expected)?;
+ test_f64_result(&r, 0, vec![0..0; 0], expected)?;
// singleton case
let expected = vec![0.0];
- test_f64_result(&r, 0..1, vec![0..1], expected)?;
+ test_f64_result(&r, 1, vec![0..1], expected)?;
// uniform case
let expected = vec![0.0; 7];
- test_f64_result(&r, 0..7, vec![0..7], expected)?;
+ test_f64_result(&r, 7, vec![0..7], expected)?;
// non-trivial case
let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5];
- test_f64_result(&r, 0..7, vec![0..3, 3..7], expected)?;
+ test_f64_result(&r, 7, vec![0..3, 3..7], expected)?;
Ok(())
}
diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs
index f70d9ea37..b27ac29d2 100644
--- a/datafusion/physical-expr/src/window/row_number.rs
+++ b/datafusion/physical-expr/src/window/row_number.rs
@@ -24,7 +24,6 @@ use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_common::Result;
use std::any::Any;
-use std::ops::Range;
use std::sync::Arc;
/// row_number expression
@@ -69,12 +68,7 @@ impl BuiltInWindowFunctionExpr for RowNumber {
pub(crate) struct NumRowsEvaluator {}
impl PartitionEvaluator for NumRowsEvaluator {
- fn evaluate_partition(
- &self,
- _values: &[ArrayRef],
- partition: Range<usize>,
- ) -> Result<ArrayRef> {
- let num_rows = partition.end - partition.start;
+ fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
Ok(Arc::new(UInt64Array::from_iter_values(
1..(num_rows as u64) + 1,
)))
@@ -99,9 +93,8 @@ mod tests {
let values = row_number.evaluate_args(&batch)?;
let result = row_number
.create_evaluator()?
- .evaluate(&values, vec![0..8])?;
- assert_eq!(1, result.len());
- let result = as_uint64_array(&result[0])?;
+ .evaluate(&values, batch.num_rows())?;
+ let result = as_uint64_array(&result)?;
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
Ok(())
@@ -118,9 +111,8 @@ mod tests {
let values = row_number.evaluate_args(&batch)?;
let result = row_number
.create_evaluator()?
- .evaluate(&values, vec![0..8])?;
- assert_eq!(1, result.len());
- let result = as_uint64_array(&result[0])?;
+ .evaluate(&values, batch.num_rows())?;
+ let result = as_uint64_array(&result)?;
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
Ok(())
diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs
index 9dbaca76e..2a0fa86b7 100644
--- a/datafusion/physical-expr/src/window/sliding_aggregate.rs
+++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs
@@ -19,6 +19,7 @@
use std::any::Any;
use std::iter::IntoIterator;
+use std::ops::Range;
use std::sync::Arc;
use arrow::array::Array;
@@ -30,6 +31,8 @@ use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::WindowFrame;
+use crate::window::window_expr::reverse_order_bys;
+use crate::window::AggregateWindowExpr;
use crate::{expressions::PhysicalSortExpr, PhysicalExpr};
use crate::{window::WindowExpr, AggregateExpr};
@@ -89,57 +92,48 @@ impl WindowExpr for SlidingAggregateWindowExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
- let partition_columns = self.partition_columns(batch)?;
- let partition_points =
- self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results: Vec<ScalarValue> = vec![];
- for partition_range in &partition_points {
- let mut accumulator = self.aggregate.create_sliding_accumulator()?;
- let length = partition_range.end - partition_range.start;
- let (values, order_bys) =
- self.get_values_orderbys(&batch.slice(partition_range.start, length))?;
-
- let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
- let mut last_range: (usize, usize) = (0, 0);
-
- // We iterate on each row to perform a running calculation.
- // First, cur_range is calculated, then it is compared with last_range.
- for i in 0..length {
- let cur_range = window_frame_ctx.calculate_range(
- &order_bys,
- &sort_options,
- length,
- i,
- )?;
- let value = if cur_range.0 == cur_range.1 {
- // We produce None if the window is empty.
- ScalarValue::try_from(self.aggregate.field()?.data_type())?
- } else {
- // Accumulate any new rows that have entered the window:
- let update_bound = cur_range.1 - last_range.1;
- if update_bound > 0 {
- let update: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.1, update_bound))
- .collect();
- accumulator.update_batch(&update)?
- }
- // Remove rows that have now left the window:
- let retract_bound = cur_range.0 - last_range.0;
- if retract_bound > 0 {
- let retract: Vec<ArrayRef> = values
- .iter()
- .map(|v| v.slice(last_range.0, retract_bound))
- .collect();
- accumulator.retract_batch(&retract)?
- }
- accumulator.evaluate()?
- };
- row_wise_results.push(value);
- last_range = cur_range;
- }
+
+ let mut accumulator = self.aggregate.create_sliding_accumulator()?;
+ let length = batch.num_rows();
+ let (values, order_bys) = self.get_values_orderbys(batch)?;
+
+ let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
+ let mut last_range = Range { start: 0, end: 0 };
+
+ // We iterate on each row to perform a running calculation.
+ // First, cur_range is calculated, then it is compared with last_range.
+ for i in 0..length {
+ let cur_range =
+ window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?;
+ let value = if cur_range.start == cur_range.end {
+ // We produce None if the window is empty.
+ ScalarValue::try_from(self.aggregate.field()?.data_type())?
+ } else {
+ // Accumulate any new rows that have entered the window:
+ let update_bound = cur_range.end - last_range.end;
+ if update_bound > 0 {
+ let update: Vec<ArrayRef> = values
+ .iter()
+ .map(|v| v.slice(last_range.end, update_bound))
+ .collect();
+ accumulator.update_batch(&update)?
+ }
+ // Remove rows that have now left the window:
+ let retract_bound = cur_range.start - last_range.start;
+ if retract_bound > 0 {
+ let retract: Vec<ArrayRef> = values
+ .iter()
+ .map(|v| v.slice(last_range.start, retract_bound))
+ .collect();
+ accumulator.retract_batch(&retract)?
+ }
+ accumulator.evaluate()?
+ };
+ row_wise_results.push(value);
+ last_range = cur_range;
}
ScalarValue::iter_to_array(row_wise_results.into_iter())
}
@@ -155,4 +149,25 @@ impl WindowExpr for SlidingAggregateWindowExpr {
fn get_window_frame(&self) -> &Arc<WindowFrame> {
&self.window_frame
}
+
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>> {
+ self.aggregate.reverse_expr().map(|reverse_expr| {
+ let reverse_window_frame = self.window_frame.reverse();
+ if reverse_window_frame.start_bound.is_unbounded() {
+ Arc::new(AggregateWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )) as _
+ } else {
+ Arc::new(SlidingAggregateWindowExpr::new(
+ reverse_expr,
+ &self.partition_by.clone(),
+ &reverse_order_bys(&self.order_by),
+ Arc::new(self.window_frame.reverse()),
+ )) as _
+ }
+ })
+ }
}
diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs
index 2fbc6e2c4..a718fa4cd 100644
--- a/datafusion/physical-expr/src/window/window_expr.rs
+++ b/datafusion/physical-expr/src/window/window_expr.rs
@@ -17,10 +17,10 @@
use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::compute::kernels::partition::lexicographical_partition_ranges;
-use arrow::compute::kernels::sort::{SortColumn, SortOptions};
+use arrow::compute::kernels::sort::SortColumn;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::{reverse_sort_options, DataFusionError, Result};
use datafusion_expr::WindowFrame;
use std::any::Any;
use std::fmt::Debug;
@@ -86,31 +86,6 @@ pub trait WindowExpr: Send + Sync + Debug {
/// expressions that's from the window function's order by clause, empty if absent
fn order_by(&self) -> &[PhysicalSortExpr];
- /// get partition columns that can be used for partitioning, empty if absent
- fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
- self.partition_by()
- .iter()
- .map(|expr| {
- if let Some(idx) =
- self.order_by().iter().position(|key| key.expr.eq(expr))
- {
- self.order_by()[idx].clone()
- } else {
- // When ASC is true, by default NULLS LAST to be consistent with PostgreSQL's rule:
- // https://www.postgresql.org/docs/current/queries-order.html
- PhysicalSortExpr {
- expr: expr.clone(),
- options: SortOptions {
- descending: false,
- nulls_first: false,
- },
- }
- }
- .evaluate_to_sort_column(batch)
- })
- .collect()
- }
-
/// get order by columns, empty if absent
fn order_by_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
self.order_by()
@@ -121,10 +96,8 @@ pub trait WindowExpr: Send + Sync + Debug {
/// get sort columns that can be used for peer evaluation, empty if absent
fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
- let mut sort_columns = self.partition_columns(batch)?;
let order_by_columns = self.order_by_columns(batch)?;
- sort_columns.extend(order_by_columns);
- Ok(sort_columns)
+ Ok(order_by_columns)
}
/// Get values columns(argument of Window Function)
@@ -140,6 +113,22 @@ pub trait WindowExpr: Send + Sync + Debug {
Ok((values, order_bys))
}
- // Get window frame of this WindowExpr, None if absent
+ /// Get the window frame of this [WindowExpr].
fn get_window_frame(&self) -> &Arc<WindowFrame>;
+
+ /// Get the reverse expression of this [WindowExpr].
+ fn get_reverse_expr(&self) -> Option<Arc<dyn WindowExpr>>;
+}
+
+/// Reverses each ORDER BY expression, which is useful when constructing an equivalent
+/// reversed window expression. For instance, 'ORDER BY a ASC NULLS LAST' turns into
+/// 'ORDER BY a DESC NULLS FIRST'.
+pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec<PhysicalSortExpr> {
+ order_bys
+ .iter()
+ .map(|e| PhysicalSortExpr {
+ expr: e.expr.clone(),
+ options: reverse_sort_options(e.options),
+ })
+ .collect()
}
diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs
index 307ea9144..b49bd3a22 100644
--- a/datafusion/physical-expr/src/window/window_frame_state.rs
+++ b/datafusion/physical-expr/src/window/window_frame_state.rs
@@ -26,6 +26,7 @@ use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use std::cmp::min;
use std::collections::VecDeque;
use std::fmt::Debug;
+use std::ops::Range;
use std::sync::Arc;
/// This object stores the window frame state for use in incremental calculations.
@@ -68,7 +69,7 @@ impl<'a> WindowFrameContext<'a> {
sort_options: &[SortOptions],
length: usize,
idx: usize,
- ) -> Result<(usize, usize)> {
+ ) -> Result<Range<usize>> {
match *self {
WindowFrameContext::Rows(window_frame) => {
Self::calculate_range_rows(window_frame, length, idx)
@@ -99,7 +100,7 @@ impl<'a> WindowFrameContext<'a> {
window_frame: &Arc<WindowFrame>,
length: usize,
idx: usize,
- ) -> Result<(usize, usize)> {
+ ) -> Result<Range<usize>> {
let start = match window_frame.start_bound {
// UNBOUNDED PRECEDING
WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
@@ -152,7 +153,7 @@ impl<'a> WindowFrameContext<'a> {
return Err(DataFusionError::Internal("Rows should be Uint".to_string()))
}
};
- Ok((start, end))
+ Ok(Range { start, end })
}
}
@@ -171,7 +172,7 @@ impl WindowFrameStateRange {
sort_options: &[SortOptions],
length: usize,
idx: usize,
- ) -> Result<(usize, usize)> {
+ ) -> Result<Range<usize>> {
let start = match window_frame.start_bound {
WindowFrameBound::Preceding(ref n) => {
if n.is_null() {
@@ -240,7 +241,7 @@ impl WindowFrameStateRange {
}
}
};
- Ok((start, end))
+ Ok(Range { start, end })
}
/// This function does the heavy lifting when finding range boundaries. It is meant to be
@@ -333,7 +334,7 @@ impl WindowFrameStateGroups {
range_columns: &[ArrayRef],
length: usize,
idx: usize,
- ) -> Result<(usize, usize)> {
+ ) -> Result<Range<usize>> {
if range_columns.is_empty() {
return Err(DataFusionError::Execution(
"GROUPS mode requires an ORDER BY clause".to_string(),
@@ -399,7 +400,7 @@ impl WindowFrameStateGroups {
))
}
};
- Ok((start, end))
+ Ok(Range { start, end })
}
/// This function does the heavy lifting when finding group boundaries. It is meant to be