Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/10 12:02:06 UTC
[arrow-datafusion] branch master updated: [Part3] Partition and Sort Enforcement, Enforcement rule implementation (#4122)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 9c24a7911 [Part3] Partition and Sort Enforcement, Enforcement rule implementation (#4122)
9c24a7911 is described below
commit 9c24a79118c0ae0aee480bf039385a06d240b499
Author: mingmwang <mi...@ebay.com>
AuthorDate: Thu Nov 10 20:01:59 2022 +0800
[Part3] Partition and Sort Enforcement, Enforcement rule implementation (#4122)
* [Part3] Partition and Sort Enforcement, Enforcement rule implementation
* Avoid unnecessary CoalescePartitionsExec in HashJoinExec and CrossJoinExec
* Fix join key ordering
* Fix join key reordering
* join key reordering, handle more operators explicitly
* Resolve review comments, add more UT to test reorder_join_keys_to_inputs
* add length check in fn expected_expr_positions()
---
datafusion/core/src/execution/context.rs | 9 +-
.../src/physical_optimizer/coalesce_batches.rs | 31 +-
.../core/src/physical_optimizer/enforcement.rs | 2001 ++++++++++++++++++++
datafusion/core/src/physical_optimizer/mod.rs | 1 +
.../core/src/physical_plan/aggregates/mod.rs | 22 +-
.../core/src/physical_plan/joins/cross_join.rs | 26 +-
.../core/src/physical_plan/joins/hash_join.rs | 24 +-
.../src/physical_plan/joins/sort_merge_join.rs | 13 +-
datafusion/core/src/physical_plan/joins/utils.rs | 28 +-
datafusion/core/src/physical_plan/mod.rs | 88 +-
datafusion/core/src/physical_plan/planner.rs | 61 +-
datafusion/core/src/physical_plan/projection.rs | 18 +-
datafusion/core/src/physical_plan/rewrite.rs | 165 ++
.../src/physical_plan/windows/window_agg_exec.rs | 4 +-
datafusion/physical-expr/src/equivalence.rs | 175 +-
15 files changed, 2467 insertions(+), 199 deletions(-)
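For readers who want to try the new rule outside the full optimizer pipeline, here is a minimal sketch mirroring the assert_optimized! test macro added in enforcement.rs below. It assumes the enforcement module is publicly exported and uses only APIs that appear in this diff; how the input plan is built is left open.

    use std::sync::Arc;
    use datafusion::error::Result;
    use datafusion::physical_optimizer::enforcement::BasicEnforcement;
    use datafusion::physical_optimizer::PhysicalOptimizerRule;
    use datafusion::physical_plan::{displayable, ExecutionPlan};
    use datafusion::prelude::SessionConfig;

    fn enforce(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
        // BasicEnforcement inserts RepartitionExec / SortExec wherever a node's
        // required input distribution or ordering is not already satisfied.
        let rule = BasicEnforcement::new();
        let config = SessionConfig::new().with_target_partitions(10);
        let optimized = rule.optimize(plan, &config)?;
        // Print the resulting plan, as the tests below do.
        println!("{}", displayable(optimized.as_ref()).indent());
        Ok(optimized)
    }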
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 8f5210998..d3989b5bd 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -73,7 +73,6 @@ use crate::optimizer::optimizer::{OptimizerConfig, OptimizerRule};
use datafusion_sql::{ResolvedTableReference, TableReference};
use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
-use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec;
use crate::physical_optimizer::repartition::Repartition;
use crate::config::{
@@ -82,6 +81,7 @@ use crate::config::{
};
use crate::datasource::file_format::file_type::{FileCompressionType, FileType};
use crate::execution::{runtime_env::RuntimeEnv, FunctionRegistry};
+use crate::physical_optimizer::enforcement::BasicEnforcement;
use crate::physical_plan::file_format::{plan_to_csv, plan_to_json, plan_to_parquet};
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udaf::AggregateUDF;
@@ -1227,6 +1227,8 @@ pub struct SessionConfig {
pub parquet_pruning: bool,
/// Should DataFusion collect statistics after listing files
pub collect_statistics: bool,
+ /// Should the DataFusion optimizer run a top-down process to reorder the join keys
+ pub top_down_join_key_reordering: bool,
/// Configuration options
pub config_options: Arc<RwLock<ConfigOptions>>,
/// Opaque extensions.
@@ -1246,6 +1248,7 @@ impl Default for SessionConfig {
repartition_windows: true,
parquet_pruning: true,
collect_statistics: false,
+ top_down_join_key_reordering: true,
config_options: Arc::new(RwLock::new(ConfigOptions::new())),
// Assume no extensions by default.
extensions: HashMap::with_capacity_and_hasher(
@@ -1568,6 +1571,7 @@ impl SessionState {
Arc::new(AggregateStatistics::new()),
Arc::new(HashBuildProbeOrder::new()),
];
+ physical_optimizers.push(Arc::new(BasicEnforcement::new()));
if config
.config_options
.read()
@@ -1585,7 +1589,8 @@ impl SessionState {
)));
}
physical_optimizers.push(Arc::new(Repartition::new()));
- physical_optimizers.push(Arc::new(AddCoalescePartitionsExec::new()));
+ physical_optimizers.push(Arc::new(BasicEnforcement::new()));
+ // physical_optimizers.push(Arc::new(AddCoalescePartitionsExec::new()));
SessionState {
session_id,
diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs
index 913046e95..0d4085478 100644
--- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs
+++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs
@@ -23,7 +23,7 @@ use crate::{
physical_optimizer::PhysicalOptimizerRule,
physical_plan::{
coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec,
- repartition::RepartitionExec, with_new_children_if_necessary,
+ repartition::RepartitionExec, rewrite::TreeNodeRewritable,
},
};
use std::sync::Arc;
@@ -48,34 +48,25 @@ impl PhysicalOptimizerRule for CoalesceBatches {
plan: Arc<dyn crate::physical_plan::ExecutionPlan>,
_config: &crate::execution::context::SessionConfig,
) -> Result<Arc<dyn crate::physical_plan::ExecutionPlan>> {
- if plan.children().is_empty() {
- // leaf node, children cannot be replaced
- Ok(plan.clone())
- } else {
- // recurse down first
- let children = plan
- .children()
- .iter()
- .map(|child| self.optimize(child.clone(), _config))
- .collect::<Result<Vec<_>>>()?;
- let plan = with_new_children_if_necessary(plan, children)?;
+ let target_batch_size = self.target_batch_size;
+ plan.transform_up(&|plan| {
+ let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
// would be to build the coalescing logic directly into the operators
// See https://github.com/apache/arrow-datafusion/issues/139
- let plan_any = plan.as_any();
let wrap_in_coalesce = plan_any.downcast_ref::<FilterExec>().is_some()
|| plan_any.downcast_ref::<HashJoinExec>().is_some()
|| plan_any.downcast_ref::<RepartitionExec>().is_some();
- Ok(if wrap_in_coalesce {
- Arc::new(CoalesceBatchesExec::new(
+ if wrap_in_coalesce {
+ Some(Arc::new(CoalesceBatchesExec::new(
plan.clone(),
- self.target_batch_size,
- ))
+ target_batch_size,
+ )))
} else {
- plan.clone()
- })
- }
+ None
+ }
+ })
}
fn name(&self) -> &str {
diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs
new file mode 100644
index 000000000..1eaf153a9
--- /dev/null
+++ b/datafusion/core/src/physical_optimizer/enforcement.rs
@@ -0,0 +1,2001 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Enforcement optimizer rules are used to make sure the plan's Distribution and Ordering
+//! requirements are met by inserting necessary [[RepartitionExec]] and [[SortExec]].
+//!
+use crate::error::Result;
+use crate::physical_optimizer::PhysicalOptimizerRule;
+use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
+use crate::physical_plan::joins::{
+ CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec,
+};
+use crate::physical_plan::projection::ProjectionExec;
+use crate::physical_plan::repartition::RepartitionExec;
+use crate::physical_plan::rewrite::TreeNodeRewritable;
+use crate::physical_plan::sorts::sort::SortExec;
+use crate::physical_plan::windows::WindowAggExec;
+use crate::physical_plan::Partitioning;
+use crate::physical_plan::{with_new_children_if_necessary, Distribution, ExecutionPlan};
+use crate::prelude::SessionConfig;
+use datafusion_expr::logical_plan::JoinType;
+use datafusion_physical_expr::equivalence::EquivalenceProperties;
+use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::expressions::NoOp;
+use datafusion_physical_expr::{
+ expr_list_eq_strict_order, normalize_expr_with_equivalence_properties,
+ normalize_sort_expr_with_equivalence_properties, PhysicalExpr, PhysicalSortExpr,
+};
+use std::collections::HashMap;
+use std::sync::Arc;
+
+/// The BasicEnforcement rule ensures that Distribution and Ordering requirements are met
+/// in the strictest way. It might add additional [[RepartitionExec]] operators to the plan
+/// tree and produce a non-optimal plan, but it avoids possible data skew in joins.
+///
+/// For example, for a HashJoin with keys (a, b, c), the required Distribution(a, b, c) can be
+/// satisfied by several alternative partitionings: [(a, b, c), (a, b), (a, c), (b, c), (a), (b), (c), ( )].
+///
+/// This rule only chooses the exact match and satisfies Distribution(a, b, c) with a HashPartition(a, b, c).
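+///
+/// For example (illustrative, with target_partitions = 10; see the multi_hash_joins test
+/// below), a partitioned hash join on (a == b1) over unpartitioned scans becomes:
+///
+/// HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a, b1)]
+///   RepartitionExec: partitioning=Hash([a], 10)
+///     ParquetExec: ...
+///   RepartitionExec: partitioning=Hash([b1], 10)
+///     ParquetExec: ...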
+#[derive(Default)]
+pub struct BasicEnforcement {}
+
+impl BasicEnforcement {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl PhysicalOptimizerRule for BasicEnforcement {
+ fn optimize(
+ &self,
+ plan: Arc<dyn ExecutionPlan>,
+ config: &SessionConfig,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ let target_partitions = config.target_partitions;
+ let top_down_join_key_reordering = config.top_down_join_key_reordering;
+ let new_plan = if top_down_join_key_reordering {
+ // Run a top-down process to adjust input key ordering recursively
+ adjust_input_keys_down_recursively(plan, vec![])?
+ } else {
+ plan
+ };
+ // Distribution and Ordering enforcement needs to be applied bottom-up.
+ new_plan.transform_up(&{
+ |plan| {
+ let adjusted = if !top_down_join_key_reordering {
+ reorder_join_keys_to_inputs(plan)
+ } else {
+ plan
+ };
+ Some(ensure_distribution_and_ordering(
+ adjusted,
+ target_partitions,
+ ))
+ }
+ })
+ }
+
+ fn name(&self) -> &str {
+ "BasicEnforcement"
+ }
+}
+
+/// When the physical planner creates the Joins, the ordering of the join keys comes from the
+/// original query, which might not match the output partitioning of the join node's children.
+/// This method runs a top-down process and tries to adjust the output partitioning of the
+/// children if the children themselves are Joins or Aggregations.
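+///
+/// For example (illustrative; see the multi_hash_join_key_ordering test below), if a parent
+/// join requires its child join to be partitioned on (b, c, a) while the child was created
+/// with keys (a == a1, b == b1, c == c1), the child's keys are reordered to
+/// (b == b1, c == c1, a == a1) so that a single HashPartition(b, c, a) serves both joins.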
+fn adjust_input_keys_down_recursively(
+ plan: Arc<dyn crate::physical_plan::ExecutionPlan>,
+ parent_required: Vec<Arc<dyn PhysicalExpr>>,
+) -> Result<Arc<dyn crate::physical_plan::ExecutionPlan>> {
+ let plan_any = plan.as_any();
+ if let Some(HashJoinExec {
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ mode,
+ null_equals_null,
+ ..
+ }) = plan_any.downcast_ref::<HashJoinExec>()
+ {
+ match mode {
+ PartitionMode::Partitioned => {
+ let join_key_pairs = extract_join_keys(on);
+ if let Some((
+ JoinKeyPairs {
+ left_keys,
+ right_keys,
+ },
+ new_positions,
+ )) = try_reorder(
+ join_key_pairs.clone(),
+ parent_required,
+ &plan.equivalence_properties(),
+ ) {
+ let new_join_on = if !new_positions.is_empty() {
+ new_join_conditions(&left_keys, &right_keys)
+ } else {
+ on.clone()
+ };
+ let new_left =
+ adjust_input_keys_down_recursively(left.clone(), left_keys)?;
+ let new_right =
+ adjust_input_keys_down_recursively(right.clone(), right_keys)?;
+ Ok(Arc::new(HashJoinExec::try_new(
+ new_left,
+ new_right,
+ new_join_on,
+ filter.clone(),
+ join_type,
+ PartitionMode::Partitioned,
+ null_equals_null,
+ )?))
+ } else {
+ let new_left = adjust_input_keys_down_recursively(
+ left.clone(),
+ join_key_pairs.left_keys,
+ )?;
+ let new_right = adjust_input_keys_down_recursively(
+ right.clone(),
+ join_key_pairs.right_keys,
+ )?;
+ Ok(Arc::new(HashJoinExec::try_new(
+ new_left,
+ new_right,
+ on.clone(),
+ filter.clone(),
+ join_type,
+ PartitionMode::Partitioned,
+ null_equals_null,
+ )?))
+ }
+ }
+ PartitionMode::CollectLeft => {
+ let new_left = adjust_input_keys_down_recursively(left.clone(), vec![])?;
+ let new_right = match join_type {
+ JoinType::Inner | JoinType::Right => try_push_required_to_right(
+ parent_required,
+ right.clone(),
+ left.schema().fields().len(),
+ )?,
+ JoinType::RightSemi | JoinType::RightAnti => {
+ adjust_input_keys_down_recursively(
+ right.clone(),
+ parent_required.clone(),
+ )?
+ }
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::Full => {
+ adjust_input_keys_down_recursively(right.clone(), vec![])?
+ }
+ };
+
+ Ok(Arc::new(HashJoinExec::try_new(
+ new_left,
+ new_right,
+ on.clone(),
+ filter.clone(),
+ join_type,
+ PartitionMode::CollectLeft,
+ null_equals_null,
+ )?))
+ }
+ }
+ } else if let Some(CrossJoinExec { left, right, .. }) =
+ plan_any.downcast_ref::<CrossJoinExec>()
+ {
+ let new_left = adjust_input_keys_down_recursively(left.clone(), vec![])?;
+ let new_right = try_push_required_to_right(
+ parent_required,
+ right.clone(),
+ left.schema().fields().len(),
+ )?;
+ Ok(Arc::new(CrossJoinExec::try_new(new_left, new_right)?))
+ } else if let Some(SortMergeJoinExec {
+ left,
+ right,
+ on,
+ join_type,
+ sort_options,
+ null_equals_null,
+ ..
+ }) = plan_any.downcast_ref::<SortMergeJoinExec>()
+ {
+ let join_key_pairs = extract_join_keys(on);
+ if let Some((
+ JoinKeyPairs {
+ left_keys,
+ right_keys,
+ },
+ new_positions,
+ )) = try_reorder(
+ join_key_pairs.clone(),
+ parent_required,
+ &plan.equivalence_properties(),
+ ) {
+ let new_join_on = if !new_positions.is_empty() {
+ new_join_conditions(&left_keys, &right_keys)
+ } else {
+ on.clone()
+ };
+ let new_options = if !new_positions.is_empty() {
+ let mut new_sort_options = vec![];
+ for idx in 0..sort_options.len() {
+ new_sort_options.push(sort_options[new_positions[idx]])
+ }
+ new_sort_options
+ } else {
+ sort_options.clone()
+ };
+
+ let new_left = adjust_input_keys_down_recursively(left.clone(), left_keys)?;
+ let new_right =
+ adjust_input_keys_down_recursively(right.clone(), right_keys)?;
+
+ Ok(Arc::new(SortMergeJoinExec::try_new(
+ new_left,
+ new_right,
+ new_join_on,
+ *join_type,
+ new_options,
+ *null_equals_null,
+ )?))
+ } else {
+ let new_left = adjust_input_keys_down_recursively(
+ left.clone(),
+ join_key_pairs.left_keys,
+ )?;
+ let new_right = adjust_input_keys_down_recursively(
+ right.clone(),
+ join_key_pairs.right_keys,
+ )?;
+ Ok(Arc::new(SortMergeJoinExec::try_new(
+ new_left,
+ new_right,
+ on.clone(),
+ *join_type,
+ sort_options.clone(),
+ *null_equals_null,
+ )?))
+ }
+ } else if let Some(AggregateExec {
+ mode,
+ group_by,
+ aggr_expr,
+ input,
+ input_schema,
+ ..
+ }) = plan_any.downcast_ref::<AggregateExec>()
+ {
+ if parent_required.is_empty() {
+ plan.map_children(|plan| adjust_input_keys_down_recursively(plan, vec![]))
+ } else {
+ match mode {
+ AggregateMode::Final => plan.map_children(|plan| {
+ adjust_input_keys_down_recursively(plan, vec![])
+ }),
+ AggregateMode::FinalPartitioned | AggregateMode::Partial => {
+ let out_put_columns = group_by
+ .expr()
+ .iter()
+ .enumerate()
+ .map(|(index, (_col, name))| Column::new(name, index))
+ .collect::<Vec<_>>();
+
+ let out_put_exprs = out_put_columns
+ .iter()
+ .map(|c| Arc::new(c.clone()) as Arc<dyn PhysicalExpr>)
+ .collect::<Vec<_>>();
+
+ // Check whether the requirements can be satisfied by the Aggregation
+ if parent_required.len() != out_put_exprs.len()
+ || expr_list_eq_strict_order(&out_put_exprs, &parent_required)
+ || !group_by.null_expr().is_empty()
+ {
+ plan.map_children(|plan| {
+ adjust_input_keys_down_recursively(plan, vec![])
+ })
+ } else {
+ let new_positions =
+ expected_expr_positions(&out_put_exprs, &parent_required);
+ match new_positions {
+ Some(positions) => {
+ let mut new_group_exprs = vec![];
+ for idx in positions.into_iter() {
+ new_group_exprs.push(group_by.expr()[idx].clone());
+ }
+ let new_group_by =
+ PhysicalGroupBy::new_single(new_group_exprs);
+ match mode {
+ AggregateMode::FinalPartitioned => {
+ // Since the input of FinalPartitioned should be the Partial AggregateExec and they should
+ // share the same column order, it's safe to call adjust_input_keys_down_recursively here
+ let new_input =
+ adjust_input_keys_down_recursively(
+ input.clone(),
+ parent_required,
+ )?;
+ let new_agg = Arc::new(AggregateExec::try_new(
+ AggregateMode::FinalPartitioned,
+ new_group_by,
+ aggr_expr.clone(),
+ new_input,
+ input_schema.clone(),
+ )?);
+
+ // Need to create a new projection to change the expr ordering back
+ let mut proj_exprs = out_put_columns
+ .iter()
+ .map(|col| {
+ (
+ Arc::new(Column::new(
+ col.name(),
+ new_agg
+ .schema()
+ .index_of(col.name())
+ .unwrap(),
+ ))
+ as Arc<dyn PhysicalExpr>,
+ col.name().to_owned(),
+ )
+ })
+ .collect::<Vec<_>>();
+ let agg_schema = new_agg.schema();
+ let agg_fields = agg_schema.fields();
+ for (idx, field) in agg_fields
+ .iter()
+ .enumerate()
+ .skip(out_put_columns.len())
+ {
+ proj_exprs.push((
+ Arc::new(Column::new(
+ field.name().as_str(),
+ idx,
+ ))
+ as Arc<dyn PhysicalExpr>,
+ field.name().clone(),
+ ))
+ }
+ // TODO: merge adjacent Projections if there are any
+ Ok(Arc::new(ProjectionExec::try_new(
+ proj_exprs, new_agg,
+ )?))
+ }
+ AggregateMode::Partial => {
+ let new_input =
+ adjust_input_keys_down_recursively(
+ input.clone(),
+ vec![],
+ )?;
+ Ok(Arc::new(AggregateExec::try_new(
+ AggregateMode::Partial,
+ new_group_by,
+ aggr_expr.clone(),
+ new_input,
+ input_schema.clone(),
+ )?))
+ }
+ _ => Ok(plan),
+ }
+ }
+ None => plan.map_children(|plan| {
+ adjust_input_keys_down_recursively(plan, vec![])
+ }),
+ }
+ }
+ }
+ }
+ }
+ } else if let Some(ProjectionExec { expr, .. }) =
+ plan_any.downcast_ref::<ProjectionExec>()
+ {
+ // For a Projection, we need to map the required columns back to the columns
+ // before the Projection and then push down the requirements.
+ // Construct a mapping from the new name to the original Column
+ let mut column_mapping = HashMap::new();
+ for (expression, name) in expr.iter() {
+ if let Some(column) = expression.as_any().downcast_ref::<Column>() {
+ column_mapping.insert(name.clone(), column.clone());
+ };
+ }
+ let new_required: Vec<Arc<dyn PhysicalExpr>> = parent_required
+ .iter()
+ .filter_map(|r| {
+ if let Some(column) = r.as_any().downcast_ref::<Column>() {
+ column_mapping.get(column.name())
+ } else {
+ None
+ }
+ })
+ .map(|e| Arc::new(e.clone()) as Arc<dyn PhysicalExpr>)
+ .collect::<Vec<_>>();
+ if new_required.len() == parent_required.len() {
+ plan.map_children(|plan| {
+ adjust_input_keys_down_recursively(plan, new_required.clone())
+ })
+ } else {
+ plan.map_children(|plan| adjust_input_keys_down_recursively(plan, vec![]))
+ }
+ } else if plan_any.downcast_ref::<RepartitionExec>().is_some()
+ || plan_any.downcast_ref::<CoalescePartitionsExec>().is_some()
+ || plan_any.downcast_ref::<WindowAggExec>().is_some()
+ {
+ plan.map_children(|plan| adjust_input_keys_down_recursively(plan, vec![]))
+ } else {
+ plan.map_children(|plan| {
+ adjust_input_keys_down_recursively(plan, parent_required.clone())
+ })
+ }
+}
+
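+/// Push the parent's required expressions down to the right child of a join, rebasing each
+/// Column's index by the number of columns on the left side. If any required expression is
+/// not a Column coming from the right side, nothing is pushed down and the right child is
+/// visited with empty requirements.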
+fn try_push_required_to_right(
+ parent_required: Vec<Arc<dyn PhysicalExpr>>,
+ right: Arc<dyn ExecutionPlan>,
+ left_columns_len: usize,
+) -> Result<Arc<dyn ExecutionPlan>> {
+ let new_required: Vec<Arc<dyn PhysicalExpr>> = parent_required
+ .iter()
+ .filter_map(|r| {
+ if let Some(col) = r.as_any().downcast_ref::<Column>() {
+ if col.index() >= left_columns_len {
+ Some(
+ Arc::new(Column::new(col.name(), col.index() - left_columns_len))
+ as Arc<dyn PhysicalExpr>,
+ )
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
+ // If the parent requirements all come from the right side, they can be pushed down
+ if new_required.len() == parent_required.len() {
+ adjust_input_keys_down_recursively(right.clone(), new_required)
+ } else {
+ adjust_input_keys_down_recursively(right.clone(), vec![])
+ }
+}
+
+/// When the physical planner creates the Joins, the ordering of the join keys comes from the
+/// original query, which might not match the output partitioning of the join node's children.
+/// This method will try to change the ordering of the join keys to match the partitioning of
+/// the join node's children.
+/// If it cannot match both sides, it will try to match one side, either the left or the right.
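+///
+/// For example (illustrative; see the reorder_join_keys_to_left_input test below), a top
+/// join created on (B == b1, C == c, AA == a1) whose left input is already hash partitioned
+/// on (a, b, c) is rewritten to join on (AA == a1, B == b1, C == c), matching the left
+/// child's partitioning through the alias equivalences.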
+fn reorder_join_keys_to_inputs(
+ plan: Arc<dyn crate::physical_plan::ExecutionPlan>,
+) -> Arc<dyn crate::physical_plan::ExecutionPlan> {
+ let plan_any = plan.as_any();
+ if let Some(HashJoinExec {
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ mode,
+ null_equals_null,
+ ..
+ }) = plan_any.downcast_ref::<HashJoinExec>()
+ {
+ match mode {
+ PartitionMode::Partitioned => {
+ let join_key_pairs = extract_join_keys(on);
+ if let Some((
+ JoinKeyPairs {
+ left_keys,
+ right_keys,
+ },
+ new_positions,
+ )) = reorder_current_join_keys(
+ join_key_pairs,
+ Some(left.output_partitioning()),
+ Some(right.output_partitioning()),
+ &left.equivalence_properties(),
+ &right.equivalence_properties(),
+ ) {
+ if !new_positions.is_empty() {
+ let new_join_on = new_join_conditions(&left_keys, &right_keys);
+ Arc::new(
+ HashJoinExec::try_new(
+ left.clone(),
+ right.clone(),
+ new_join_on,
+ filter.clone(),
+ join_type,
+ PartitionMode::Partitioned,
+ null_equals_null,
+ )
+ .unwrap(),
+ )
+ } else {
+ plan
+ }
+ } else {
+ plan
+ }
+ }
+ _ => plan,
+ }
+ } else if let Some(SortMergeJoinExec {
+ left,
+ right,
+ on,
+ join_type,
+ sort_options,
+ null_equals_null,
+ ..
+ }) = plan_any.downcast_ref::<SortMergeJoinExec>()
+ {
+ let join_key_pairs = extract_join_keys(on);
+ if let Some((
+ JoinKeyPairs {
+ left_keys,
+ right_keys,
+ },
+ new_positions,
+ )) = reorder_current_join_keys(
+ join_key_pairs,
+ Some(left.output_partitioning()),
+ Some(right.output_partitioning()),
+ &left.equivalence_properties(),
+ &right.equivalence_properties(),
+ ) {
+ if !new_positions.is_empty() {
+ let new_join_on = new_join_conditions(&left_keys, &right_keys);
+ let mut new_sort_options = vec![];
+ for idx in 0..sort_options.len() {
+ new_sort_options.push(sort_options[new_positions[idx]])
+ }
+ Arc::new(
+ SortMergeJoinExec::try_new(
+ left.clone(),
+ right.clone(),
+ new_join_on,
+ *join_type,
+ new_sort_options,
+ *null_equals_null,
+ )
+ .unwrap(),
+ )
+ } else {
+ plan
+ }
+ } else {
+ plan
+ }
+ } else {
+ plan
+ }
+}
+
+/// Reorder the current join keys based on either the left or the right child's hash partitioning.
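+///
+/// The left child's hash partitioning is tried first; if the join keys cannot be matched
+/// against it (directly or via the equivalence classes), the same reordering is attempted
+/// against the right child's hash partitioning.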
+fn reorder_current_join_keys(
+ join_keys: JoinKeyPairs,
+ left_partition: Option<Partitioning>,
+ right_partition: Option<Partitioning>,
+ left_equivalence_properties: &EquivalenceProperties,
+ right_equivalence_properties: &EquivalenceProperties,
+) -> Option<(JoinKeyPairs, Vec<usize>)> {
+ match (left_partition, right_partition.clone()) {
+ (Some(Partitioning::Hash(left_exprs, _)), _) => {
+ try_reorder(join_keys.clone(), left_exprs, left_equivalence_properties)
+ .or_else(|| {
+ reorder_current_join_keys(
+ join_keys,
+ None,
+ right_partition,
+ left_equivalence_properties,
+ right_equivalence_properties,
+ )
+ })
+ }
+ (_, Some(Partitioning::Hash(right_exprs, _))) => {
+ try_reorder(join_keys, right_exprs, right_equivalence_properties)
+ }
+ _ => None,
+ }
+}
+
+fn try_reorder(
+ join_keys: JoinKeyPairs,
+ expected: Vec<Arc<dyn PhysicalExpr>>,
+ equivalence_properties: &EquivalenceProperties,
+) -> Option<(JoinKeyPairs, Vec<usize>)> {
+ let mut normalized_expected = vec![];
+ let mut normalized_left_keys = vec![];
+ let mut normalized_right_keys = vec![];
+ if join_keys.left_keys.len() != expected.len() {
+ return None;
+ }
+ if expr_list_eq_strict_order(&expected, &join_keys.left_keys)
+ || expr_list_eq_strict_order(&expected, &join_keys.right_keys)
+ {
+ return Some((join_keys, vec![]));
+ } else if !equivalence_properties.classes().is_empty() {
+ normalized_expected = expected
+ .iter()
+ .map(|e| {
+ normalize_expr_with_equivalence_properties(
+ e.clone(),
+ equivalence_properties.classes(),
+ )
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(normalized_expected.len(), expected.len());
+
+ normalized_left_keys = join_keys
+ .left_keys
+ .iter()
+ .map(|e| {
+ normalize_expr_with_equivalence_properties(
+ e.clone(),
+ equivalence_properties.classes(),
+ )
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(join_keys.left_keys.len(), normalized_left_keys.len());
+
+ normalized_right_keys = join_keys
+ .right_keys
+ .iter()
+ .map(|e| {
+ normalize_expr_with_equivalence_properties(
+ e.clone(),
+ equivalence_properties.classes(),
+ )
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(join_keys.right_keys.len(), normalized_right_keys.len());
+
+ if expr_list_eq_strict_order(&normalized_expected, &normalized_left_keys)
+ || expr_list_eq_strict_order(&normalized_expected, &normalized_right_keys)
+ {
+ return Some((join_keys, vec![]));
+ }
+ }
+
+ let new_positions = expected_expr_positions(&join_keys.left_keys, &expected)
+ .or_else(|| expected_expr_positions(&join_keys.right_keys, &expected))
+ .or_else(|| expected_expr_positions(&normalized_left_keys, &normalized_expected))
+ .or_else(|| {
+ expected_expr_positions(&normalized_right_keys, &normalized_expected)
+ });
+
+ if let Some(positions) = new_positions {
+ let mut new_left_keys = vec![];
+ let mut new_right_keys = vec![];
+ for pos in positions.iter() {
+ new_left_keys.push(join_keys.left_keys[*pos].clone());
+ new_right_keys.push(join_keys.right_keys[*pos].clone());
+ }
+ Some((
+ JoinKeyPairs {
+ left_keys: new_left_keys,
+ right_keys: new_right_keys,
+ },
+ positions,
+ ))
+ } else {
+ None
+ }
+}
+
+/// Return the positions of the expected expressions within the current expressions.
+/// For example, if the current expressions are ['c', 'a', 'a', 'b'] and the expected expressions are ['b', 'c', 'a', 'a'],
+///
+/// this method will return Some(vec![3, 0, 1, 2])
+fn expected_expr_positions(
+ current: &[Arc<dyn PhysicalExpr>],
+ expected: &[Arc<dyn PhysicalExpr>],
+) -> Option<Vec<usize>> {
+ if current.is_empty() || expected.is_empty() {
+ return None;
+ }
+ let mut indexes: Vec<usize> = vec![];
+ let mut current = current.to_vec();
+ for expr in expected.iter() {
+ // Find the position of the expected expr in the current expressions
+ if let Some(expected_position) = current.iter().position(|e| e.eq(expr)) {
+ current[expected_position] = Arc::new(NoOp::new());
+ indexes.push(expected_position);
+ } else {
+ return None;
+ }
+ }
+ Some(indexes)
+}
+
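+/// Split the join `on` conditions into parallel lists of left and right join key expressions.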
+fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs {
+ let (left_keys, right_keys) = on
+ .iter()
+ .map(|(l, r)| {
+ (
+ Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+ Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+ )
+ })
+ .unzip();
+ JoinKeyPairs {
+ left_keys,
+ right_keys,
+ }
+}
+
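+/// Rebuild the join `on` conditions from reordered left/right key lists. The keys must be
+/// plain Columns, as produced by extract_join_keys.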
+fn new_join_conditions(
+ new_left_keys: &[Arc<dyn PhysicalExpr>],
+ new_right_keys: &[Arc<dyn PhysicalExpr>],
+) -> Vec<(Column, Column)> {
+ let new_join_on = new_left_keys
+ .iter()
+ .zip(new_right_keys.iter())
+ .map(|(l_key, r_key)| {
+ (
+ l_key.as_any().downcast_ref::<Column>().unwrap().clone(),
+ r_key.as_any().downcast_ref::<Column>().unwrap().clone(),
+ )
+ })
+ .collect::<Vec<_>>();
+ new_join_on
+}
+
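+/// Make each child of `plan` satisfy its required input distribution (by adding a
+/// CoalescePartitionsExec or a RepartitionExec) and its required input ordering (by adding
+/// a SortExec), leaving children that already satisfy the requirements untouched.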
+fn ensure_distribution_and_ordering(
+ plan: Arc<dyn crate::physical_plan::ExecutionPlan>,
+ target_partitions: usize,
+) -> Arc<dyn crate::physical_plan::ExecutionPlan> {
+ if plan.children().is_empty() {
+ return plan;
+ }
+ let required_input_distributions = plan.required_input_distribution();
+ let required_input_orderings = plan.required_input_ordering();
+ let children: Vec<Arc<dyn ExecutionPlan>> = plan.children();
+ assert_eq!(children.len(), required_input_distributions.len());
+ assert_eq!(children.len(), required_input_orderings.len());
+
+ // Add RepartitionExec to guarantee output partitioning
+ let children = children
+ .into_iter()
+ .zip(required_input_distributions.into_iter())
+ .map(|(child, required)| {
+ if child
+ .output_partitioning()
+ .satisfy(required.clone(), || child.equivalence_properties())
+ {
+ child
+ } else {
+ let new_child: Arc<dyn ExecutionPlan> = match required {
+ Distribution::SinglePartition
+ if child.output_partitioning().partition_count() > 1 =>
+ {
+ Arc::new(CoalescePartitionsExec::new(child.clone()))
+ }
+ _ => {
+ let partition = required.create_partitioning(target_partitions);
+ Arc::new(RepartitionExec::try_new(child, partition).unwrap())
+ }
+ };
+ new_child
+ }
+ });
+
+ // Add SortExec to guarantee output ordering
+ let new_children: Vec<Arc<dyn ExecutionPlan>> = children
+ .zip(required_input_orderings.into_iter())
+ .map(|(child, required)| {
+ if ordering_satisfy(child.output_ordering(), required, || {
+ child.equivalence_properties()
+ }) {
+ child
+ } else {
+ let sort_expr = required.unwrap().to_vec();
+ if child.output_partitioning().partition_count() > 1 {
+ Arc::new(SortExec::new_with_partitioning(
+ sort_expr, child, true, None,
+ ))
+ } else {
+ Arc::new(SortExec::try_new(sort_expr, child, None).unwrap())
+ }
+ }
+ })
+ .collect::<Vec<_>>();
+
+ with_new_children_if_necessary(plan, new_children).unwrap()
+}
+
+/// Check whether the required ordering requirements are satisfied by the provided PhysicalSortExprs.
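+///
+/// For example (illustrative), a provided ordering [a ASC, b ASC] satisfies a required
+/// [a ASC] because the required ordering is matched as a prefix; and given an equivalence
+/// class {a, a1}, a provided [a ASC] satisfies a required [a1 ASC] after both sides are
+/// normalized.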
+fn ordering_satisfy<F: FnOnce() -> EquivalenceProperties>(
+ provided: Option<&[PhysicalSortExpr]>,
+ required: Option<&[PhysicalSortExpr]>,
+ equal_properties: F,
+) -> bool {
+ match (provided, required) {
+ (_, None) => true,
+ (None, Some(_)) => false,
+ (Some(provided), Some(required)) => {
+ if required.len() > provided.len() {
+ false
+ } else {
+ let fast_match = required
+ .iter()
+ .zip(provided.iter())
+ .all(|(order1, order2)| order1.eq(order2));
+
+ if !fast_match {
+ let eq_properties = equal_properties();
+ let eq_classes = eq_properties.classes();
+ if !eq_classes.is_empty() {
+ let normalized_required_exprs = required
+ .iter()
+ .map(|e| {
+ normalize_sort_expr_with_equivalence_properties(
+ e.clone(),
+ eq_classes,
+ )
+ })
+ .collect::<Vec<_>>();
+ let normalized_provided_exprs = provided
+ .iter()
+ .map(|e| {
+ normalize_sort_expr_with_equivalence_properties(
+ e.clone(),
+ eq_classes,
+ )
+ })
+ .collect::<Vec<_>>();
+ normalized_required_exprs
+ .iter()
+ .zip(normalized_provided_exprs.iter())
+ .all(|(order1, order2)| order1.eq(order2))
+ } else {
+ fast_match
+ }
+ } else {
+ fast_match
+ }
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct JoinKeyPairs {
+ left_keys: Vec<Arc<dyn PhysicalExpr>>,
+ right_keys: Vec<Arc<dyn PhysicalExpr>>,
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::physical_plan::filter::FilterExec;
+ use arrow::compute::SortOptions;
+ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+ use datafusion_expr::logical_plan::JoinType;
+ use datafusion_expr::Operator;
+ use datafusion_physical_expr::expressions::binary;
+ use datafusion_physical_expr::expressions::lit;
+ use datafusion_physical_expr::expressions::Column;
+ use datafusion_physical_expr::{expressions, PhysicalExpr};
+ use std::ops::Deref;
+
+ use super::*;
+ use crate::config::ConfigOptions;
+ use crate::datasource::listing::PartitionedFile;
+ use crate::datasource::object_store::ObjectStoreUrl;
+ use crate::physical_plan::aggregates::{
+ AggregateExec, AggregateMode, PhysicalGroupBy,
+ };
+ use crate::physical_plan::expressions::col;
+ use crate::physical_plan::file_format::{FileScanConfig, ParquetExec};
+ use crate::physical_plan::joins::{
+ utils::JoinOn, HashJoinExec, PartitionMode, SortMergeJoinExec,
+ };
+ use crate::physical_plan::projection::ProjectionExec;
+ use crate::physical_plan::{displayable, Statistics};
+
+ fn schema() -> SchemaRef {
+ Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ Field::new("d", DataType::Int32, true),
+ Field::new("e", DataType::Boolean, true),
+ ]))
+ }
+
+ fn parquet_exec() -> Arc<ParquetExec> {
+ Arc::new(ParquetExec::new(
+ FileScanConfig {
+ object_store_url: ObjectStoreUrl::parse("test:///").unwrap(),
+ file_schema: schema(),
+ file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]],
+ statistics: Statistics::default(),
+ projection: None,
+ limit: None,
+ table_partition_cols: vec![],
+ config_options: ConfigOptions::new().into_shareable(),
+ },
+ None,
+ None,
+ ))
+ }
+
+ fn projection_exec_with_alias(
+ input: Arc<dyn ExecutionPlan>,
+ alias_pairs: Vec<(String, String)>,
+ ) -> Arc<dyn ExecutionPlan> {
+ let mut exprs = vec![];
+ for (column, alias) in alias_pairs.iter() {
+ exprs.push((col(column, &input.schema()).unwrap(), alias.to_string()));
+ }
+ Arc::new(ProjectionExec::try_new(exprs, input).unwrap())
+ }
+
+ fn aggregate_exec_with_alias(
+ input: Arc<dyn ExecutionPlan>,
+ alias_pairs: Vec<(String, String)>,
+ ) -> Arc<dyn ExecutionPlan> {
+ let schema = schema();
+ let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+ for (column, alias) in alias_pairs.iter() {
+ group_by_expr
+ .push((col(column, &input.schema()).unwrap(), alias.to_string()));
+ }
+ let group_by = PhysicalGroupBy::new_single(group_by_expr.clone());
+
+ let final_group_by_expr = group_by_expr
+ .iter()
+ .enumerate()
+ .map(|(index, (_col, name))| {
+ (
+ Arc::new(expressions::Column::new(name, index))
+ as Arc<dyn PhysicalExpr>,
+ name.clone(),
+ )
+ })
+ .collect::<Vec<_>>();
+ let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr);
+
+ Arc::new(
+ AggregateExec::try_new(
+ AggregateMode::FinalPartitioned,
+ final_grouping,
+ vec![],
+ Arc::new(
+ AggregateExec::try_new(
+ AggregateMode::Partial,
+ group_by,
+ vec![],
+ input,
+ schema.clone(),
+ )
+ .unwrap(),
+ ),
+ schema,
+ )
+ .unwrap(),
+ )
+ }
+
+ fn hash_join_exec(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ join_on: &JoinOn,
+ join_type: &JoinType,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ HashJoinExec::try_new(
+ left,
+ right,
+ join_on.clone(),
+ None,
+ join_type,
+ PartitionMode::Partitioned,
+ &false,
+ )
+ .unwrap(),
+ )
+ }
+
+ fn sort_merge_join_exec(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ join_on: &JoinOn,
+ join_type: &JoinType,
+ ) -> Arc<dyn ExecutionPlan> {
+ Arc::new(
+ SortMergeJoinExec::try_new(
+ left,
+ right,
+ join_on.clone(),
+ *join_type,
+ vec![SortOptions::default(); join_on.len()],
+ false,
+ )
+ .unwrap(),
+ )
+ }
+
+ fn trim_plan_display(plan: &str) -> Vec<&str> {
+ plan.split('\n')
+ .map(|s| s.trim())
+ .filter(|s| !s.is_empty())
+ .collect()
+ }
+
+ /// Runs the BasicEnforcement optimizer rule and asserts the plan against the expected
+ macro_rules! assert_optimized {
+ ($EXPECTED_LINES: expr, $PLAN: expr) => {
+ let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect();
+
+ // run optimizer
+ let optimizer = BasicEnforcement {};
+ let optimized = optimizer
+ .optimize($PLAN, &SessionConfig::new().with_target_partitions(10))?;
+
+ // Now format correctly
+ let plan = displayable(optimized.as_ref()).indent().to_string();
+ let actual_lines = trim_plan_display(&plan);
+
+ assert_eq!(
+ &expected_lines, &actual_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_lines, actual_lines
+ );
+ };
+ }
+
+ macro_rules! assert_plan_txt {
+ ($EXPECTED_LINES: expr, $PLAN: expr) => {
+ let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect();
+ // Now format correctly
+ let plan = displayable($PLAN.as_ref()).indent().to_string();
+ let actual_lines = trim_plan_display(&plan);
+
+ assert_eq!(
+ &expected_lines, &actual_lines,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected_lines, actual_lines
+ );
+ };
+ }
+
+ #[test]
+ fn multi_hash_joins() -> Result<()> {
+ let left = parquet_exec();
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ("c".to_string(), "c1".to_string()),
+ ("d".to_string(), "d1".to_string()),
+ ("e".to_string(), "e1".to_string()),
+ ];
+ let right = projection_exec_with_alias(parquet_exec(), alias_pairs);
+ let join_types = vec![
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ JoinType::RightSemi,
+ JoinType::RightAnti,
+ ];
+
+ // Join on (a == b1)
+ let join_on = vec![(
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ )];
+
+ for join_type in join_types {
+ let join = hash_join_exec(left.clone(), right.clone(), &join_on, &join_type);
+ let join_plan =
+ format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]", join_type);
+
+ match join_type {
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Right
+ | JoinType::Full
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti => {
+ // Join on (a == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("a", &join.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+ let top_join = hash_join_exec(
+ join.clone(),
+ parquet_exec(),
+ &top_join_on,
+ &join_type,
+ );
+ let top_join_plan =
+ format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]", join_type);
+
+ let expected = match join_type {
+ // Should include 3 RepartitionExecs
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![
+ top_join_plan.as_str(),
+ join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ // Should include 4 RepartitionExecs
+ _ => vec![
+ top_join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ };
+ assert_optimized!(expected, top_join);
+ }
+ JoinType::RightSemi | JoinType::RightAnti => {}
+ }
+
+ match join_type {
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::Right
+ | JoinType::Full
+ | JoinType::RightSemi
+ | JoinType::RightAnti => {
+ // This time we use (b1 == c) for top join
+ // Join on (b1 == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("b1", &join.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+
+ let top_join =
+ hash_join_exec(join, parquet_exec(), &top_join_on, &join_type);
+ let top_join_plan = match join_type {
+ JoinType::RightSemi | JoinType::RightAnti =>
+ format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"b1\", index: 1 }}, Column {{ name: \"c\", index: 2 }})]", join_type),
+ _ =>
+ format!("HashJoinExec: mode=Partitioned, join_type={}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]", join_type),
+ };
+
+ let expected = match join_type {
+ // Should include 3 RepartitionExecs
+ JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti =>
+ vec![
+ top_join_plan.as_str(),
+ join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ // Should include 4 RepartitionExecs
+ _ =>
+ vec![
+ top_join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10)",
+ join_plan.as_str(),
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ };
+ assert_optimized!(expected, top_join);
+ }
+ JoinType::LeftSemi | JoinType::LeftAnti => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn multi_joins_after_alias() -> Result<()> {
+ let left = parquet_exec();
+ let right = parquet_exec();
+
+ // Join on (a == b)
+ let join_on = vec![(
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("b", &schema()).unwrap(),
+ )];
+ let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Projection(a as a1, a as a2)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("a".to_string(), "a2".to_string()),
+ ];
+ let projection = projection_exec_with_alias(join, alias_pairs);
+
+ // Join on (a1 == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("a1", &projection.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+
+ let top_join = hash_join_exec(
+ projection.clone(),
+ right.clone(),
+ &top_join_on,
+ &JoinType::Inner,
+ );
+
+ // The output partitioning needs to respect the alias and should not introduce an additional RepartitionExec
+ let expected = &[
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"c\", index: 2 })]",
+ "ProjectionExec: expr=[a@0 as a1, a@0 as a2]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+ assert_optimized!(expected, top_join);
+
+ // Join on (a2 == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("a2", &projection.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+
+ let top_join = hash_join_exec(projection, right, &top_join_on, &JoinType::Inner);
+
+ // The output partitioning needs to respect the alias and should not introduce an additional RepartitionExec
+ let expected = &[
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a2\", index: 1 }, Column { name: \"c\", index: 2 })]",
+ "ProjectionExec: expr=[a@0 as a1, a@0 as a2]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+
+ assert_optimized!(expected, top_join);
+ Ok(())
+ }
+
+ #[test]
+ fn multi_joins_after_multi_alias() -> Result<()> {
+ let left = parquet_exec();
+ let right = parquet_exec();
+
+ // Join on (a == b)
+ let join_on = vec![(
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("b", &schema()).unwrap(),
+ )];
+
+ let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Projection(c as c1)
+ let alias_pairs: Vec<(String, String)> =
+ vec![("c".to_string(), "c1".to_string())];
+ let projection = projection_exec_with_alias(join, alias_pairs);
+
+ // Projection(c1 as a)
+ let alias_pairs: Vec<(String, String)> =
+ vec![("c1".to_string(), "a".to_string())];
+ let projection2 = projection_exec_with_alias(projection, alias_pairs);
+
+ // Join on (a == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("a", &projection2.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+
+ let top_join = hash_join_exec(projection2, right, &top_join_on, &JoinType::Inner);
+
+ // The Column 'a' has a different meaning after the two Projections.
+ // The original output partitioning cannot satisfy the Join requirements, so an additional RepartitionExec is needed
+ let expected = &[
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"c\", index: 2 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ProjectionExec: expr=[c1@0 as a]",
+ "ProjectionExec: expr=[c@2 as c1]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+
+ assert_optimized!(expected, top_join);
+ Ok(())
+ }
+
+ #[test]
+ fn join_after_agg_alias() -> Result<()> {
+ // group by (a as a1)
+ let left = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![("a".to_string(), "a1".to_string())],
+ );
+ // group by (a as a2)
+ let right = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![("a".to_string(), "a2".to_string())],
+ );
+
+ // Join on (a1 == a2)
+ let join_on = vec![(
+ Column::new_with_schema("a1", &left.schema()).unwrap(),
+ Column::new_with_schema("a2", &right.schema()).unwrap(),
+ )];
+ let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Only two RepartitionExecs added
+ let expected = &[
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a1\", index: 0 }, Column { name: \"a2\", index: 0 })]",
+ "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }], 10)",
+ "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a2\", index: 0 }], 10)",
+ "AggregateExec: mode=Partial, gby=[a@0 as a2], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+ assert_optimized!(expected, join);
+ Ok(())
+ }
+
+ #[test]
+ fn hash_join_key_ordering() -> Result<()> {
+ // group by (a as a1, b as b1)
+ let left = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ],
+ );
+ // group by (b, a)
+ let right = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![
+ ("b".to_string(), "b".to_string()),
+ ("a".to_string(), "a".to_string()),
+ ],
+ );
+
+ // Join on (b1 == b && a1 == a)
+ let join_on = vec![
+ (
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("a1", &left.schema()).unwrap(),
+ Column::new_with_schema("a", &right.schema()).unwrap(),
+ ),
+ ];
+ let join = hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Only two RepartitionExecs added
+ let expected = &[
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: \"a\", index: 1 })]",
+ "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
+ "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
+ "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10)",
+ "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+ assert_optimized!(expected, join);
+ Ok(())
+ }
+
+ #[test]
+ fn multi_hash_join_key_ordering() -> Result<()> {
+ let left = parquet_exec();
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ("c".to_string(), "c1".to_string()),
+ ];
+ let right = projection_exec_with_alias(parquet_exec(), alias_pairs);
+
+ // Join on (a == a1 and b == b1 and c == c1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("c", &schema()).unwrap(),
+ Column::new_with_schema("c1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_left_join =
+ hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner);
+
+ // Projection(a as A, a as AA, b as B, c as C)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "A".to_string()),
+ ("a".to_string(), "AA".to_string()),
+ ("b".to_string(), "B".to_string()),
+ ("c".to_string(), "C".to_string()),
+ ];
+ let bottom_left_projection =
+ projection_exec_with_alias(bottom_left_join, alias_pairs);
+
+ // Join on (c == c1 and b == b1 and a == a1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("c", &schema()).unwrap(),
+ Column::new_with_schema("c1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_right_join =
+ hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Join on (B == b1 and C == c and AA == a1)
+ let top_join_on = vec![
+ (
+ Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(),
+ ),
+ ];
+
+ let top_join = hash_join_exec(
+ bottom_left_projection.clone(),
+ bottom_right_join,
+ &top_join_on,
+ &JoinType::Inner,
+ );
+
+ let predicate: Arc<dyn PhysicalExpr> = binary(
+ col("c", top_join.schema().deref())?,
+ Operator::Gt,
+ lit(1i64),
+ top_join.schema().deref(),
+ )?;
+
+ let filter_top_join: Arc<dyn ExecutionPlan> =
+ Arc::new(FilterExec::try_new(predicate, top_join)?);
+
+ // The bottom joins' join key ordering is adjusted based on the top join, and the top join should not introduce an additional RepartitionExec
+ let expected = &[
+ "FilterExec: c@6 > 1",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"B\", index: 2 }, Column { name: \"b1\", index: 6 }), (Column { name: \"C\", index: 3 }, Column { name: \"c\", index: 2 }), (Column { name: \"AA\", index: 1 }, Column { name: \"a1\", index: 5 })]",
+ "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }, Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }, Column { name: \"a1\", index: 0 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+ assert_optimized!(expected, filter_top_join);
+ Ok(())
+ }
+
+ #[test]
+ fn reorder_join_keys_to_left_input() -> Result<()> {
+ let left = parquet_exec();
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ("c".to_string(), "c1".to_string()),
+ ];
+ let right = projection_exec_with_alias(parquet_exec(), alias_pairs);
+
+ // Join on (a == a1 and b == b1 and c == c1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("c", &schema()).unwrap(),
+ Column::new_with_schema("c1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_left_join = ensure_distribution_and_ordering(
+ hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner),
+ 10,
+ );
+
+ // Projection(a as A, a as AA, b as B, c as C)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "A".to_string()),
+ ("a".to_string(), "AA".to_string()),
+ ("b".to_string(), "B".to_string()),
+ ("c".to_string(), "C".to_string()),
+ ];
+ let bottom_left_projection =
+ projection_exec_with_alias(bottom_left_join, alias_pairs);
+
+ // Join on (c == c1 and b == b1 and a == a1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("c", &schema()).unwrap(),
+ Column::new_with_schema("c1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_right_join = ensure_distribution_and_ordering(
+ hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner),
+ 10,
+ );
+
+ // Join on (B == b1 and C == c and AA == a1)
+ let top_join_on = vec![
+ (
+ Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(),
+ ),
+ ];
+
+ let join_types = vec![
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ JoinType::RightSemi,
+ JoinType::RightAnti,
+ ];
+
+ for join_type in join_types {
+ let top_join = hash_join_exec(
+ bottom_left_projection.clone(),
+ bottom_right_join.clone(),
+ &top_join_on,
+ &join_type,
+ );
+ let top_join_plan =
+ format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(Column {{ name: \"AA\", index: 1 }}, Column {{ name: \"a1\", index: 5 }}), (Column {{ name: \"B\", index: 2 }}, Column {{ name: \"b1\", index: 6 }}), (Column {{ name: \"C\", index: 3 }}, Column {{ name: \"c\", index: 2 }})]", &join_type);
+
+ let reordered = reorder_join_keys_to_inputs(top_join);
+
+ // The top join's join key ordering is adjusted based on its children's inputs.
+ let expected = &[
+ top_join_plan.as_str(),
+ "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 }, Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }, Column { name: \"b1\", index: 1 }, Column { name: \"c1\", index: 2 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }, Column { name: \"b\", index: 1 }, Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 2 }, Column { name: \"b1\", index: 1 }, Column { name: \"a1\", index: 0 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+
+ assert_plan_txt!(expected, reordered);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn reorder_join_keys_to_right_input() -> Result<()> {
+ let left = parquet_exec();
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ("c".to_string(), "c1".to_string()),
+ ];
+ let right = projection_exec_with_alias(parquet_exec(), alias_pairs);
+
+ // Join on (a == a1 and b == b1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_left_join = ensure_distribution_and_ordering(
+ hash_join_exec(left.clone(), right.clone(), &join_on, &JoinType::Inner),
+ 10,
+ );
+
+ // Projection(a as A, a as AA, b as B, c as C)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "A".to_string()),
+ ("a".to_string(), "AA".to_string()),
+ ("b".to_string(), "B".to_string()),
+ ("c".to_string(), "C".to_string()),
+ ];
+ let bottom_left_projection =
+ projection_exec_with_alias(bottom_left_join, alias_pairs);
+
+ // Join on (c == c1 and b == b1 and a == a1)
+ let join_on = vec![
+ (
+ Column::new_with_schema("c", &schema()).unwrap(),
+ Column::new_with_schema("c1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("b", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("a1", &right.schema()).unwrap(),
+ ),
+ ];
+ let bottom_right_join = ensure_distribution_and_ordering(
+ hash_join_exec(left, right.clone(), &join_on, &JoinType::Inner),
+ 10,
+ );
+
+ // Join on (B == b1 and C == c and AA == a1)
+ let top_join_on = vec![
+ (
+ Column::new_with_schema("B", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("b1", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("C", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("c", &bottom_right_join.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("AA", &bottom_left_projection.schema()).unwrap(),
+ Column::new_with_schema("a1", &bottom_right_join.schema()).unwrap(),
+ ),
+ ];
+
+ let join_types = vec![
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ JoinType::RightSemi,
+ JoinType::RightAnti,
+ ];
+
+ for join_type in join_types {
+ let top_join = hash_join_exec(
+ bottom_left_projection.clone(),
+ bottom_right_join.clone(),
+ &top_join_on,
+ &join_type,
+ );
+ let top_join_plan =
+ format!("HashJoinExec: mode=Partitioned, join_type={:?}, on=[(Column {{ name: \"C\", index: 3 }}, Column {{ name: \"c\", index: 2 }}), (Column {{ name: \"B\", index: 2 }}, Column {{ name: \"b1\", index: 6 }}), (Column {{ name: \"AA\", index: 1 }}, Column {{ name: \"a1\", index: 5 }})]", &join_type);
+
+ let reordered = reorder_join_keys_to_inputs(top_join);
+
+ // The top join's join key ordering is adjusted based on its children's inputs.
+ let expected = &[
+ top_join_plan.as_str(),
+ "ProjectionExec: expr=[a@0 as A, a@0 as AA, b@1 as B, c@2 as C]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }, Column { name: \"b\", index: 1 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a1\", index: 0 }, Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"c\", index: 2 }, Column { name: \"c1\", index: 2 }), (Column { name: \"b\", index: 1 }, Column { name: \"b1\", index: 1 }), (Column { name: \"a\", index: 0 }, Column { name: \"a1\", index: 0 })]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }, Column { name: \"b\", index: 1 }, Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 2 }, Column { name: \"b1\", index: 1 }, Column { name: \"a1\", index: 0 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+
+ assert_plan_txt!(expected, reordered);
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn multi_smj_joins() -> Result<()> {
+ let left = parquet_exec();
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ("c".to_string(), "c1".to_string()),
+ ("d".to_string(), "d1".to_string()),
+ ("e".to_string(), "e1".to_string()),
+ ];
+ let right = projection_exec_with_alias(parquet_exec(), alias_pairs);
+
+ // SortMergeJoin does not currently support RightSemi or RightAnti joins
+ let join_types = vec![
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ ];
+
+ // Join on (a == b1)
+ let join_on = vec![(
+ Column::new_with_schema("a", &schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ )];
+
+ for join_type in join_types {
+ let join =
+ sort_merge_join_exec(left.clone(), right.clone(), &join_on, &join_type);
+ let join_plan =
+ format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"b1\", index: 1 }})]", join_type);
+
+ // Top join on (a == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("a", &join.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+ let top_join = sort_merge_join_exec(
+ join.clone(),
+ parquet_exec(),
+ &top_join_on,
+ &join_type,
+ );
+ let top_join_plan =
+ format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"a\", index: 0 }}, Column {{ name: \"c\", index: 2 }})]", join_type);
+
+ let expected = match join_type {
+ // Should include 3 RepartitionExecs and 3 SortExecs
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti =>
+ vec![
+ top_join_plan.as_str(),
+ join_plan.as_str(),
+ "SortExec: [a@0 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [b1@1 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [c@2 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ // Should include 4 RepartitionExecs and 4 SortExecs
+ _ => vec![
+ top_join_plan.as_str(),
+ "SortExec: [a@0 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ join_plan.as_str(),
+ "SortExec: [a@0 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [b1@1 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [c@2 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ };
+ assert_optimized!(expected, top_join);
+
+ match join_type {
+ JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
+ // This time we use (b1 == c) for top join
+ // Join on (b1 == c)
+ let top_join_on = vec![(
+ Column::new_with_schema("b1", &join.schema()).unwrap(),
+ Column::new_with_schema("c", &schema()).unwrap(),
+ )];
+ let top_join = sort_merge_join_exec(
+ join,
+ parquet_exec(),
+ &top_join_on,
+ &join_type,
+ );
+ let top_join_plan =
+ format!("SortMergeJoin: join_type={}, on=[(Column {{ name: \"b1\", index: 6 }}, Column {{ name: \"c\", index: 2 }})]", join_type);
+
+ let expected = match join_type {
+ // Should include 3 RepartitionExecs and 3 SortExecs
+ JoinType::Inner | JoinType::Right => vec![
+ top_join_plan.as_str(),
+ join_plan.as_str(),
+ "SortExec: [a@0 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [b1@1 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [c@2 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ // Should include 4 RepartitionExecs and 4 SortExecs
+ _ => vec![
+ top_join_plan.as_str(),
+ "SortExec: [b1@6 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 6 }], 10)",
+ join_plan.as_str(),
+ "SortExec: [a@0 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"a\", index: 0 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [b1@1 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 1 }], 10)",
+ "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [c@2 ASC]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"c\", index: 2 }], 10)",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ],
+ };
+ assert_optimized!(expected, top_join);
+ }
+ _ => {}
+ }
+ }
+
+ Ok(())
+ }
+
+ #[test]
+ fn smj_join_key_ordering() -> Result<()> {
+ // group by (a as a1, b as b1)
+ let left = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![
+ ("a".to_string(), "a1".to_string()),
+ ("b".to_string(), "b1".to_string()),
+ ],
+ );
+ // Projection(a1 as a3, b1 as b3)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a1".to_string(), "a3".to_string()),
+ ("b1".to_string(), "b3".to_string()),
+ ];
+ let left = projection_exec_with_alias(left, alias_pairs);
+
+ // group by (b, a)
+ let right = aggregate_exec_with_alias(
+ parquet_exec(),
+ vec![
+ ("b".to_string(), "b".to_string()),
+ ("a".to_string(), "a".to_string()),
+ ],
+ );
+
+ // Projection(a as a2, b as b2)
+ let alias_pairs: Vec<(String, String)> = vec![
+ ("a".to_string(), "a2".to_string()),
+ ("b".to_string(), "b2".to_string()),
+ ];
+ let right = projection_exec_with_alias(right, alias_pairs);
+
+ // Join on (b3 == b2 && a3 == a2)
+ let join_on = vec![
+ (
+ Column::new_with_schema("b3", &left.schema()).unwrap(),
+ Column::new_with_schema("b2", &right.schema()).unwrap(),
+ ),
+ (
+ Column::new_with_schema("a3", &left.schema()).unwrap(),
+ Column::new_with_schema("a2", &right.schema()).unwrap(),
+ ),
+ ];
+ let join = sort_merge_join_exec(left, right.clone(), &join_on, &JoinType::Inner);
+
+ // Only two RepartitionExecs added
+ let expected = &[
+ "SortMergeJoin: join_type=Inner, on=[(Column { name: \"b3\", index: 1 }, Column { name: \"b2\", index: 1 }), (Column { name: \"a3\", index: 0 }, Column { name: \"a2\", index: 0 })]",
+ "SortExec: [b3@1 ASC,a3@0 ASC]",
+ "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]",
+ "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
+ "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
+ "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ "SortExec: [b2@1 ASC,a2@0 ASC]",
+ "ProjectionExec: expr=[a@1 as a2, b@0 as b2]",
+ "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]",
+ "RepartitionExec: partitioning=Hash([Column { name: \"b\", index: 0 }, Column { name: \"a\", index: 1 }], 10)",
+ "AggregateExec: mode=Partial, gby=[b@1 as b, a@0 as a], aggr=[]",
+ "ParquetExec: limit=None, partitions=[x], projection=[a, b, c, d, e]",
+ ];
+ assert_optimized!(expected, join);
+ Ok(())
+ }
+}
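The expected plans above rely on hash-partitioning comparisons being strict about expression order: a child partitioned on [b, c, a] does not satisfy a requirement on [a, b, c] even though the key sets are equal, which is why reorder_join_keys_to_inputs rewrites the top join's `on` list instead of inserting another RepartitionExec. A minimal sketch of that order sensitivity, using only APIs visible in this diff (the `main` harness is illustrative):

```rust
use std::sync::Arc;

use datafusion_physical_expr::expr_list_eq_strict_order;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn main() {
    // Child output partitioning keys: [b, c, a].
    let bca: Vec<Arc<dyn PhysicalExpr>> = vec![
        Arc::new(Column::new("b", 1)),
        Arc::new(Column::new("c", 2)),
        Arc::new(Column::new("a", 0)),
    ];
    // Required keys: [a, b, c] -- the same set in a different order.
    let abc: Vec<Arc<dyn PhysicalExpr>> = vec![
        Arc::new(Column::new("a", 0)),
        Arc::new(Column::new("b", 1)),
        Arc::new(Column::new("c", 2)),
    ];
    // The strict comparison fails on order, so without join key reordering
    // an extra RepartitionExec would be needed.
    assert!(!expr_list_eq_strict_order(&abc, &bca));
    assert!(expr_list_eq_strict_order(&bca, &bca));
}
```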
diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs
index 55550bcd2..5ecb9cd37 100644
--- a/datafusion/core/src/physical_optimizer/mod.rs
+++ b/datafusion/core/src/physical_optimizer/mod.rs
@@ -20,6 +20,7 @@
pub mod aggregate_statistics;
pub mod coalesce_batches;
+pub mod enforcement;
pub mod hash_build_probe_order;
pub mod merge_exec;
pub mod optimizer;
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 6fda8f77b..43e75e352 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -49,6 +49,7 @@ use crate::physical_plan::aggregates::row_hash::GroupedHashAggregateStreamV2;
use crate::physical_plan::EquivalenceProperties;
pub use datafusion_expr::AggregateFunction;
use datafusion_physical_expr::aggregate::row_accumulator::RowAccumulator;
+use datafusion_physical_expr::equivalence::project_equivalence_properties;
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
use datafusion_physical_expr::normalize_out_expr_with_alias_schema;
use datafusion_row::{row_supported, RowType};
@@ -153,19 +154,19 @@ impl PhysicalGroupBy {
#[derive(Debug)]
pub struct AggregateExec {
/// Aggregation mode (full, partial)
- mode: AggregateMode,
+ pub(crate) mode: AggregateMode,
/// Group by expressions
- group_by: PhysicalGroupBy,
+ pub(crate) group_by: PhysicalGroupBy,
/// Aggregate expressions
- aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+ pub(crate) aggr_expr: Vec<Arc<dyn AggregateExpr>>,
/// Input plan, could be a partial aggregate or the input to the aggregate
- input: Arc<dyn ExecutionPlan>,
+ pub(crate) input: Arc<dyn ExecutionPlan>,
/// Schema after the aggregate is applied
schema: SchemaRef,
/// Input schema before any aggregation is applied. For partial aggregate this will be the
/// same as input.schema() but for the final aggregate it will be the same as the input
/// to the partial aggregate
- input_schema: SchemaRef,
+ pub(crate) input_schema: SchemaRef,
/// The alias map used to normalize out expressions like Partitioning and PhysicalSortExpr
/// The key is the column from the input schema and the values are the columns from the output schema
alias_map: HashMap<Column, Vec<Column>>,
@@ -315,10 +316,13 @@ impl ExecutionPlan for AggregateExec {
}
fn equivalence_properties(&self) -> EquivalenceProperties {
- let mut input_equivalence_properties = self.input.equivalence_properties();
- input_equivalence_properties.merge_properties_with_alias(&self.alias_map);
- input_equivalence_properties.truncate_properties_not_in_schema(&self.schema);
- input_equivalence_properties
+ let mut new_properties = EquivalenceProperties::new(self.schema());
+ project_equivalence_properties(
+ self.input.equivalence_properties(),
+ &self.alias_map,
+ &mut new_properties,
+ );
+ new_properties
}
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
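AggregateExec now derives its output equivalences by projecting the input properties through the alias map rather than mutating them in place. A hedged sketch of project_equivalence_properties (defined later in this diff, in datafusion/physical-expr/src/equivalence.rs) in isolation, with illustrative column names:

```rust
use std::collections::HashMap;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_physical_expr::equivalence::project_equivalence_properties;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::EquivalenceProperties;

fn main() {
    // Input relation with a single column `a`.
    let input_schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Int64,
        true,
    )]));
    let input_eq = EquivalenceProperties::new(input_schema);

    // Projection(a as a1, a as a2): the alias map records that `a` fans
    // out to `a1` and `a2` in the output schema.
    let out_schema = Arc::new(Schema::new(vec![
        Field::new("a1", DataType::Int64, true),
        Field::new("a2", DataType::Int64, true),
    ]));
    let mut alias_map = HashMap::new();
    alias_map.insert(
        Column::new("a", 0),
        vec![Column::new("a1", 0), Column::new("a2", 1)],
    );

    let mut out_eq = EquivalenceProperties::new(out_schema);
    project_equivalence_properties(input_eq, &alias_map, &mut out_eq);

    // `a1` and `a2` are equivalent; `a` itself is truncated away because
    // it does not appear in the output schema.
    assert_eq!(out_eq.classes().len(), 1);
    assert_eq!(out_eq.classes()[0].len(), 2);
}
```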
diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs b/datafusion/core/src/physical_plan/joins/cross_join.rs
index a71e06cce..170153e07 100644
--- a/datafusion/core/src/physical_plan/joins/cross_join.rs
+++ b/datafusion/core/src/physical_plan/joins/cross_join.rs
@@ -29,9 +29,9 @@ use arrow::record_batch::RecordBatch;
use crate::execution::context::TaskContext;
use crate::physical_plan::{
coalesce_batches::concat_batches, coalesce_partitions::CoalescePartitionsExec,
- ColumnStatistics, DisplayFormatType, EquivalenceProperties, ExecutionPlan,
- Partitioning, PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
- Statistics,
+ ColumnStatistics, DisplayFormatType, Distribution, EquivalenceProperties,
+ ExecutionPlan, Partitioning, PhysicalSortExpr, RecordBatchStream,
+ SendableRecordBatchStream, Statistics,
};
use crate::{error::Result, scalar::ScalarValue};
use async_trait::async_trait;
@@ -51,9 +51,9 @@ type JoinLeftData = RecordBatch;
#[derive(Debug)]
pub struct CrossJoinExec {
/// left (build) side which gets loaded in memory
- left: Arc<dyn ExecutionPlan>,
+ pub(crate) left: Arc<dyn ExecutionPlan>,
/// right (probe) side which is combined with the left side
- right: Arc<dyn ExecutionPlan>,
+ pub(crate) right: Arc<dyn ExecutionPlan>,
/// The schema once the join is applied
schema: SchemaRef,
/// Build-side data
@@ -110,7 +110,13 @@ async fn load_left_input(
let start = Instant::now();
// merge all left parts into a single stream
- let merge = CoalescePartitionsExec::new(left.clone());
+ let merge = {
+ if left.output_partitioning().partition_count() != 1 {
+ Arc::new(CoalescePartitionsExec::new(left.clone()))
+ } else {
+ left.clone()
+ }
+ };
let stream = merge.execute(0, context)?;
// Load all batches and count the rows
@@ -156,6 +162,13 @@ impl ExecutionPlan for CrossJoinExec {
)?))
}
+ fn required_input_distribution(&self) -> Vec<Distribution> {
+ vec![
+ Distribution::SinglePartition,
+ Distribution::UnspecifiedDistribution,
+ ]
+ }
+
// TODO optimize CrossJoin implementation to generate M * N partitions
fn output_partitioning(&self) -> Partitioning {
let left_columns_len = self.left.schema().fields.len();
@@ -176,6 +189,7 @@ impl ExecutionPlan for CrossJoinExec {
self.left.equivalence_properties(),
self.right.equivalence_properties(),
left_columns_len,
+ self.schema(),
)
}
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 07b4d6f4a..597e5e298 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -118,15 +118,15 @@ type JoinLeftData = (JoinHashMap, RecordBatch);
#[derive(Debug)]
pub struct HashJoinExec {
/// left (build) side which gets hashed
- left: Arc<dyn ExecutionPlan>,
+ pub(crate) left: Arc<dyn ExecutionPlan>,
/// right (probe) side which is filtered by the hash table
- right: Arc<dyn ExecutionPlan>,
+ pub(crate) right: Arc<dyn ExecutionPlan>,
/// Set of common columns used to join on
- on: Vec<(Column, Column)>,
+ pub(crate) on: Vec<(Column, Column)>,
/// Filters which are applied while finding matching rows
- filter: Option<JoinFilter>,
+ pub(crate) filter: Option<JoinFilter>,
/// How the join is performed
- join_type: JoinType,
+ pub(crate) join_type: JoinType,
/// The schema once the join is applied
schema: SchemaRef,
/// Build-side data
@@ -134,13 +134,13 @@ pub struct HashJoinExec {
/// Shares the `RandomState` for the hashing algorithm
random_state: RandomState,
/// Partitioning mode to use
- mode: PartitionMode,
+ pub(crate) mode: PartitionMode,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
/// If null_equals_null is true, null == null else null != null
- null_equals_null: bool,
+ pub(crate) null_equals_null: bool,
}
/// Metrics for HashJoinExec
@@ -337,6 +337,7 @@ impl ExecutionPlan for HashJoinExec {
self.right.equivalence_properties(),
left_columns_len,
self.on(),
+ self.schema(),
)
}
@@ -447,9 +448,14 @@ async fn collect_left_input(
) -> Result<JoinLeftData> {
let schema = left.schema();
let start = Instant::now();
-
// merge all left parts into a single stream
- let merge = CoalescePartitionsExec::new(left);
+ let merge = {
+ if left.output_partitioning().partition_count() != 1 {
+ Arc::new(CoalescePartitionsExec::new(left))
+ } else {
+ left
+ }
+ };
let stream = merge.execute(0, context)?;
// This operation performs 2 steps at once:
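Both join build sides now use the same guard around CoalescePartitionsExec. As a standalone helper the pattern looks like this (the function name is illustrative, not part of the patch):

```rust
use std::sync::Arc;

use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::ExecutionPlan;

/// Collapse a plan's output to one partition, but only pay for a
/// CoalescePartitionsExec when the input actually has several partitions.
fn into_single_partition(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
    if plan.output_partitioning().partition_count() != 1 {
        Arc::new(CoalescePartitionsExec::new(plan))
    } else {
        plan
    }
}
```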
diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
index 44771ba4c..5ea1b22a4 100644
--- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
@@ -59,13 +59,13 @@ use datafusion_physical_expr::rewrite::TreeNodeRewritable;
#[derive(Debug)]
pub struct SortMergeJoinExec {
/// Left sorted joining execution plan
- left: Arc<dyn ExecutionPlan>,
+ pub(crate) left: Arc<dyn ExecutionPlan>,
/// Right sorted joining execution plan
- right: Arc<dyn ExecutionPlan>,
+ pub(crate) right: Arc<dyn ExecutionPlan>,
/// Set of common columns used to join on
- on: JoinOn,
+ pub(crate) on: JoinOn,
/// How the join is performed
- join_type: JoinType,
+ pub(crate) join_type: JoinType,
/// The schema once the join is applied
schema: SchemaRef,
/// Execution metrics
@@ -77,9 +77,9 @@ pub struct SortMergeJoinExec {
/// The output ordering
output_ordering: Option<Vec<PhysicalSortExpr>>,
/// Sort options of join columns used in sorting left and right execution plans
- sort_options: Vec<SortOptions>,
+ pub(crate) sort_options: Vec<SortOptions>,
/// If null_equals_null is true, null == null else null != null
- null_equals_null: bool,
+ pub(crate) null_equals_null: bool,
}
impl SortMergeJoinExec {
@@ -258,6 +258,7 @@ impl ExecutionPlan for SortMergeJoinExec {
self.right.equivalence_properties(),
left_columns_len,
self.on(),
+ self.schema(),
)
}
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 905e59de9..cc71ee6af 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -20,6 +20,7 @@
use crate::error::{DataFusionError, Result};
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
+use crate::physical_plan::SchemaRef;
use arrow::datatypes::{Field, Schema};
use arrow::error::ArrowError;
use datafusion_common::ScalarValue;
@@ -143,10 +144,12 @@ pub fn combine_join_equivalence_properties(
right_properties: EquivalenceProperties,
left_columns_len: usize,
on: &[(Column, Column)],
+ schema: SchemaRef,
) -> EquivalenceProperties {
- let mut new_properties = match join_type {
+ let mut new_properties = EquivalenceProperties::new(schema);
+ match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
- let mut left_properties = left_properties;
+ new_properties.extend(left_properties.classes().to_vec());
let new_right_properties = right_properties
.classes()
.iter()
@@ -166,12 +169,15 @@ pub fn combine_join_equivalence_properties(
})
.collect::<Vec<_>>();
- left_properties.extend(new_right_properties);
- left_properties
+ new_properties.extend(new_right_properties);
}
- JoinType::LeftSemi | JoinType::LeftAnti => left_properties,
- JoinType::RightSemi | JoinType::RightAnti => right_properties,
- };
+ JoinType::LeftSemi | JoinType::LeftAnti => {
+ new_properties.extend(left_properties.classes().to_vec())
+ }
+ JoinType::RightSemi | JoinType::RightAnti => {
+ new_properties.extend(right_properties.classes().to_vec())
+ }
+ }
if join_type == JoinType::Inner {
on.iter().for_each(|(column1, column2)| {
@@ -188,8 +194,10 @@ pub fn cross_join_equivalence_properties(
left_properties: EquivalenceProperties,
right_properties: EquivalenceProperties,
left_columns_len: usize,
+ schema: SchemaRef,
) -> EquivalenceProperties {
- let mut left_properties = left_properties;
+ let mut new_properties = EquivalenceProperties::new(schema);
+ new_properties.extend(left_properties.classes().to_vec());
let new_right_properties = right_properties
.classes()
.iter()
@@ -204,8 +212,8 @@ pub fn cross_join_equivalence_properties(
EquivalentClass::new(new_head, new_others)
})
.collect::<Vec<_>>();
- left_properties.extend(new_right_properties);
- left_properties
+ new_properties.extend(new_right_properties);
+ new_properties
}
/// Used in ColumnIndex to distinguish which side the index is for
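Threading the output schema through lets the combined properties be validated against the join schema, in which right-side columns are shifted by left_columns_len. A small sketch of the equivalence that an inner join's `on` pair contributes, using only APIs from this diff:

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::EquivalenceProperties;

fn main() {
    // Output schema of `left JOIN right ON a = a1` with left = [a, b] and
    // right = [a1, b1]: right columns are shifted by left_columns_len (2).
    let join_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Int64, true),
        Field::new("a1", DataType::Int64, true),
        Field::new("b1", DataType::Int64, true),
    ]));

    let mut eq = EquivalenceProperties::new(join_schema);
    // For an inner join, each `on` pair yields an equivalence between the
    // left column and the shifted right column, mirroring what
    // combine_join_equivalence_properties does above.
    eq.add_equal_conditions((&Column::new("a", 0), &Column::new("a1", 2)));

    assert_eq!(eq.classes().len(), 1);
    assert_eq!(eq.classes()[0].len(), 2);
}
```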
diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index 87df26781..0bbb8bcab 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -125,11 +125,7 @@ pub trait ExecutionPlan: Debug + Send + Sync {
/// Specifies the data distribution requirements for all the
/// children of this operator. By default it is [[Distribution::UnspecifiedDistribution]] for each child.
fn required_input_distribution(&self) -> Vec<Distribution> {
- if !self.children().is_empty() {
- vec![Distribution::UnspecifiedDistribution; self.children().len()]
- } else {
- vec![Distribution::UnspecifiedDistribution]
- }
+ vec![Distribution::UnspecifiedDistribution; self.children().len()]
}
/// Specifies the ordering requirements for all the
@@ -197,7 +193,7 @@ pub trait ExecutionPlan: Debug + Send + Sync {
/// Get the EquivalenceProperties within the plan
fn equivalence_properties(&self) -> EquivalenceProperties {
- EquivalenceProperties::new()
+ EquivalenceProperties::new(self.schema())
}
/// Get a list of child execution plans that provide the input for this plan. The returned list
@@ -477,6 +473,66 @@ impl Partitioning {
RoundRobinBatch(n) | Hash(_, n) | UnknownPartitioning(n) => *n,
}
}
+
+ /// Returns true when the guarantees made by this [[Partitioning]] are sufficient to
+ /// satisfy the partitioning scheme mandated by the `required` [[Distribution]]
+ pub fn satisfy<F: FnOnce() -> EquivalenceProperties>(
+ &self,
+ required: Distribution,
+ equal_properties: F,
+ ) -> bool {
+ match required {
+ Distribution::UnspecifiedDistribution => true,
+ Distribution::SinglePartition if self.partition_count() == 1 => true,
+ Distribution::HashPartitioned(required_exprs) => {
+ match self {
+ // Here we do not check the partition count for hash partitioning, and assume the partition count
+ // and hash functions in the system are the same. In the future, if we plan to support storage partition-wise joins,
+ // then we will need to validate the partition count and hash functions.
+ Partitioning::Hash(partition_exprs, _) => {
+ let fast_match =
+ expr_list_eq_strict_order(&required_exprs, partition_exprs);
+ // If the required exprs do not match, we need to leverage the eq_properties provided by the child
+ // and normalize both sets of exprs based on the eq_properties
+ if !fast_match {
+ let eq_properties = equal_properties();
+ let eq_classes = eq_properties.classes();
+ if !eq_classes.is_empty() {
+ let normalized_required_exprs = required_exprs
+ .iter()
+ .map(|e| {
+ normalize_expr_with_equivalence_properties(
+ e.clone(),
+ eq_classes,
+ )
+ })
+ .collect::<Vec<_>>();
+ let normalized_partition_exprs = partition_exprs
+ .iter()
+ .map(|e| {
+ normalize_expr_with_equivalence_properties(
+ e.clone(),
+ eq_classes,
+ )
+ })
+ .collect::<Vec<_>>();
+ expr_list_eq_strict_order(
+ &normalized_required_exprs,
+ &normalized_partition_exprs,
+ )
+ } else {
+ fast_match
+ }
+ } else {
+ fast_match
+ }
+ }
+ _ => false,
+ }
+ }
+ _ => false,
+ }
+ }
}
impl PartialEq for Partitioning {
@@ -508,10 +564,27 @@ pub enum Distribution {
HashPartitioned(Vec<Arc<dyn PhysicalExpr>>),
}
-use datafusion_physical_expr::expr_list_eq_strict_order;
+impl Distribution {
+ /// Creates a Partitioning that satisfies this Distribution
+ pub fn create_partitioning(&self, partition_count: usize) -> Partitioning {
+ match self {
+ Distribution::UnspecifiedDistribution => {
+ Partitioning::UnknownPartitioning(partition_count)
+ }
+ Distribution::SinglePartition => Partitioning::UnknownPartitioning(1),
+ Distribution::HashPartitioned(expr) => {
+ Partitioning::Hash(expr.clone(), partition_count)
+ }
+ }
+ }
+}
+
use datafusion_physical_expr::expressions::Column;
pub use datafusion_physical_expr::window::WindowExpr;
use datafusion_physical_expr::EquivalenceProperties;
+use datafusion_physical_expr::{
+ expr_list_eq_strict_order, normalize_expr_with_equivalence_properties,
+};
pub use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
/// Applies an optional projection to a [`SchemaRef`], returning the
@@ -571,6 +644,7 @@ pub mod metrics;
pub mod planner;
pub mod projection;
pub mod repartition;
+pub mod rewrite;
pub mod sorts;
pub mod stream;
pub mod udaf;
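To see how Partitioning::satisfy, equivalence normalization, and Distribution::create_partitioning fit together, here is a hedged sketch built only from the APIs added above (the scenario itself is illustrative):

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::physical_plan::{Distribution, Partitioning};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, true),
        Field::new("b", DataType::Int64, true),
    ]));

    // Child output: hash partitioned on `a` into 10 partitions.
    let partitioning = Partitioning::Hash(
        vec![Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>],
        10,
    );
    // Parent requirement: hash partitioned on `b`.
    let required = || {
        Distribution::HashPartitioned(vec![
            Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>
        ])
    };

    // Without equivalence information the requirement is not met.
    let no_eq = EquivalenceProperties::new(schema.clone());
    assert!(!partitioning.satisfy(required(), || no_eq.clone()));

    // If a predicate established a = b, normalization lets the existing
    // partitioning pass and no RepartitionExec is needed.
    let mut eq = EquivalenceProperties::new(schema);
    eq.add_equal_conditions((&Column::new("a", 0), &Column::new("b", 1)));
    assert!(partitioning.satisfy(required(), || eq.clone()));

    // When a requirement is not met, the enforcement rule can derive the
    // partitioning to install from the requirement itself.
    let fix = required().create_partitioning(10);
    assert!(matches!(fix, Partitioning::Hash(_, 10)));
}
```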
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 88a43e111..729649e7a 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -539,19 +539,6 @@ impl DefaultPhysicalPlanner {
vec![]
};
- let input_exec = if can_repartition {
- Arc::new(RepartitionExec::try_new(
- input_exec,
- Partitioning::Hash(
- physical_partition_keys.clone(),
- session_state.config.target_partitions,
- ),
- )?)
- } else {
- input_exec
- };
-
- // add a sort phase
let get_sort_keys = |expr: &Expr| match expr {
Expr::WindowFunction {
ref partition_by,
@@ -609,16 +596,6 @@ impl DefaultPhysicalPlanner {
Some(sort_keys)
};
- let input_exec = match physical_sort_keys.clone() {
- None => input_exec,
- Some(sort_exprs) => {
- if can_repartition {
- Arc::new(SortExec::new_with_partitioning(sort_exprs, input_exec, true, None))
- } else {
- Arc::new(SortExec::try_new(sort_exprs, input_exec, None)?)
- }
- },
- };
let physical_input_schema = input_exec.schema();
let window_expr = window_expr
.iter()
@@ -688,16 +665,8 @@ impl DefaultPhysicalPlanner {
Arc<dyn ExecutionPlan>,
AggregateMode,
) = if can_repartition {
- // Divide partial hash aggregates into multiple partitions by hash key
- let hash_repartition = Arc::new(RepartitionExec::try_new(
- initial_aggr,
- Partitioning::Hash(
- final_group.clone(),
- session_state.config.target_partitions,
- ),
- )?);
- // Combine hash aggregates within the partition
- (hash_repartition, AggregateMode::FinalPartitioned)
+ // construct a second aggregation with 'AggregateMode::FinalPartitioned'
+ (initial_aggr, AggregateMode::FinalPartitioned)
} else {
// construct a second aggregation, keeping the final column name equal to the
// first aggregation and the expressions corresponding to the respective aggregate
@@ -965,32 +934,10 @@ impl DefaultPhysicalPlanner {
if session_state.config.target_partitions > 1
&& session_state.config.repartition_joins
{
- let (left_expr, right_expr) = join_on
- .iter()
- .map(|(l, r)| {
- (
- Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
- Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
- )
- })
- .unzip();
-
// Use hash partition by default to parallelize hash joins
Ok(Arc::new(HashJoinExec::try_new(
- Arc::new(RepartitionExec::try_new(
- physical_left,
- Partitioning::Hash(
- left_expr,
- session_state.config.target_partitions,
- ),
- )?),
- Arc::new(RepartitionExec::try_new(
- physical_right,
- Partitioning::Hash(
- right_expr,
- session_state.config.target_partitions,
- ),
- )?),
+ physical_left,
+ physical_right,
join_on,
join_filter,
join_type,
diff --git a/datafusion/core/src/physical_plan/projection.rs b/datafusion/core/src/physical_plan/projection.rs
index 2b6297f8c..692880a83 100644
--- a/datafusion/core/src/physical_plan/projection.rs
+++ b/datafusion/core/src/physical_plan/projection.rs
@@ -40,6 +40,7 @@ use super::expressions::{Column, PhysicalSortExpr};
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream, Statistics};
use crate::execution::context::TaskContext;
+use datafusion_physical_expr::equivalence::project_equivalence_properties;
use datafusion_physical_expr::normalize_out_expr_with_alias_schema;
use futures::stream::Stream;
use futures::stream::StreamExt;
@@ -48,7 +49,7 @@ use futures::stream::StreamExt;
#[derive(Debug)]
pub struct ProjectionExec {
/// The projection expressions stored as tuples of (expression, output column name)
- expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ pub(crate) expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
/// The schema once the projection has been applied to the input
schema: SchemaRef,
/// The input plan
@@ -191,15 +192,14 @@ impl ExecutionPlan for ProjectionExec {
true
}
- // Equivalence properties need to be adjusted after the Projection.
- // 1) Add Alias, Alias can introduce additional equivalence properties,
- // For example: Projection(a, a as a1, a as a2)
- // 2) Truncate the properties that are not in the schema of the Projection
fn equivalence_properties(&self) -> EquivalenceProperties {
- let mut input_equivalence_properties = self.input.equivalence_properties();
- input_equivalence_properties.merge_properties_with_alias(&self.alias_map);
- input_equivalence_properties.truncate_properties_not_in_schema(&self.schema);
- input_equivalence_properties
+ let mut new_properties = EquivalenceProperties::new(self.schema());
+ project_equivalence_properties(
+ self.input.equivalence_properties(),
+ &self.alias_map,
+ &mut new_properties,
+ );
+ new_properties
}
fn with_new_children(
diff --git a/datafusion/core/src/physical_plan/rewrite.rs b/datafusion/core/src/physical_plan/rewrite.rs
new file mode 100644
index 000000000..1dfc36eb1
--- /dev/null
+++ b/datafusion/core/src/physical_plan/rewrite.rs
@@ -0,0 +1,165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Trait to make ExecutionPlan rewritable
+
+use crate::physical_plan::with_new_children_if_necessary;
+use crate::physical_plan::ExecutionPlan;
+use datafusion_common::Result;
+
+use std::sync::Arc;
+
+/// A trait for marking tree node types that are rewritable
+pub trait TreeNodeRewritable: Clone {
+ /// Transform the tree node using the given [TreeNodeRewriter]
+ /// It performs a depth-first walk of a node and its children.
+ ///
+ /// For a node tree such as
+ /// ```text
+ /// ParentNode
+ /// left: ChildNode1
+ /// right: ChildNode2
+ /// ```
+ ///
+ /// The nodes are visited using the following order
+ /// ```text
+ /// pre_visit(ParentNode)
+ /// pre_visit(ChildNode1)
+ /// mutate(ChildNode1)
+ /// pre_visit(ChildNode2)
+ /// mutate(ChildNode2)
+ /// mutate(ParentNode)
+ /// ```
+ ///
+ /// If an Err result is returned, recursion is stopped immediately.
+ ///
+ /// If [`RewriteRecursion::Stop`] is returned on a call to pre_visit, no
+ /// children of that node are visited, nor is mutate
+ /// called on that node.
+ ///
+ fn transform_using<R: TreeNodeRewriter<Self>>(
+ self,
+ rewriter: &mut R,
+ ) -> Result<Self> {
+ let need_mutate = match rewriter.pre_visit(&self)? {
+ RewriteRecursion::Mutate => return rewriter.mutate(self),
+ RewriteRecursion::Stop => return Ok(self),
+ RewriteRecursion::Continue => true,
+ RewriteRecursion::Skip => false,
+ };
+
+ let after_op_children =
+ self.map_children(|node| node.transform_using(rewriter))?;
+
+ // now rewrite this node itself
+ if need_mutate {
+ rewriter.mutate(after_op_children)
+ } else {
+ Ok(after_op_children)
+ }
+ }
+
+ /// Convenience utility for writing optimizer rules: recursively apply the given `op` to the node tree.
+ /// When `op` does not apply to a given node, it is left unchanged.
+ /// The default tree traversal direction is transform_up (postorder traversal).
+ fn transform<F>(self, op: &F) -> Result<Self>
+ where
+ F: Fn(Self) -> Option<Self>,
+ {
+ self.transform_up(op)
+ }
+
+ /// Convenience utility for writing optimizer rules: recursively apply the given `op` to the node and all of its
+ /// children (preorder traversal).
+ /// When the `op` does not apply to a given node, it is left unchanged.
+ fn transform_down<F>(self, op: &F) -> Result<Self>
+ where
+ F: Fn(Self) -> Option<Self>,
+ {
+ let node_cloned = self.clone();
+ let after_op = match op(node_cloned) {
+ Some(value) => value,
+ None => self,
+ };
+ after_op.map_children(|node| node.transform_down(op))
+ }
+
+ /// Convenience utility for writing optimizer rules: recursively apply the given `op` first to all of its
+ /// children and then to the node itself (postorder traversal).
+ /// When the `op` does not apply to a given node, it is left unchanged.
+ fn transform_up<F>(self, op: &F) -> Result<Self>
+ where
+ F: Fn(Self) -> Option<Self>,
+ {
+ let after_op_children = self.map_children(|node| node.transform_up(op))?;
+
+ let after_op_children_clone = after_op_children.clone();
+ let new_node = match op(after_op_children) {
+ Some(value) => value,
+ None => after_op_children_clone,
+ };
+ Ok(new_node)
+ }
+
+ /// Apply transform `F` to the node's children; the transform `F` might have a direction (preorder or postorder)
+ fn map_children<F>(self, transform: F) -> Result<Self>
+ where
+ F: FnMut(Self) -> Result<Self>;
+}
+
+/// Trait for potentially recursively transforming a [`TreeNodeRewritable`] node
+/// tree. When passed to `TreeNodeRewritable::transform_using`, `TreeNodeRewriter::mutate` is
+/// invoked recursively on all nodes of a tree.
+pub trait TreeNodeRewriter<N: TreeNodeRewritable>: Sized {
+ /// Invoked before (Preorder) any children of `node` are rewritten /
+ /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
+ fn pre_visit(&mut self, _node: &N) -> Result<RewriteRecursion> {
+ Ok(RewriteRecursion::Continue)
+ }
+
+ /// Invoked after (Postorder) all children of `node` have been mutated and
+ /// returns a potentially modified node.
+ fn mutate(&mut self, node: N) -> Result<N>;
+}
+
+/// Controls how the [TreeNodeRewriter] recursion should proceed.
+#[allow(dead_code)]
+pub enum RewriteRecursion {
+ /// Continue rewrite / visit this node tree.
+ Continue,
+ /// Call [`TreeNodeRewriter::mutate`] immediately and return.
+ Mutate,
+ /// Do not rewrite / visit the children of this node; return it unchanged.
+ Stop,
+ /// Keep recursing into the children, but skip calling [`TreeNodeRewriter::mutate`] on this node.
+ Skip,
+}
+
+impl TreeNodeRewritable for Arc<dyn ExecutionPlan> {
+ fn map_children<F>(self, transform: F) -> Result<Self>
+ where
+ F: FnMut(Self) -> Result<Self>,
+ {
+ if !self.children().is_empty() {
+ let new_children: Result<Vec<_>> =
+ self.children().into_iter().map(transform).collect();
+ with_new_children_if_necessary(self, new_children?)
+ } else {
+ Ok(self)
+ }
+ }
+}
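Because TreeNodeRewritable only demands Clone plus a map_children implementation, the traversal helpers can be exercised on any tree shape. A self-contained sketch with a toy node type (Node and the doubling closure are illustrative, not part of the patch):

```rust
use datafusion::physical_plan::rewrite::TreeNodeRewritable;
use datafusion_common::Result;

#[derive(Clone)]
struct Node {
    value: i64,
    children: Vec<Node>,
}

impl TreeNodeRewritable for Node {
    fn map_children<F>(mut self, transform: F) -> Result<Self>
    where
        F: FnMut(Self) -> Result<Self>,
    {
        self.children = self
            .children
            .into_iter()
            .map(transform)
            .collect::<Result<Vec<_>>>()?;
        Ok(self)
    }
}

fn main() -> Result<()> {
    let tree = Node {
        value: 1,
        children: vec![Node { value: 2, children: vec![] }],
    };
    // transform_up rewrites children before their parent (postorder), so
    // the closure always sees a node whose subtree is already rewritten.
    let doubled = tree.transform_up(&|mut n: Node| {
        n.value *= 2;
        Some(n)
    })?;
    assert_eq!(doubled.value, 2);
    assert_eq!(doubled.children[0].value, 4);
    Ok(())
}
```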
diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
index 76ad0afb1..248c4570d 100644
--- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
+++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs
@@ -36,7 +36,7 @@ use arrow::{
};
use futures::stream::Stream;
use futures::{ready, StreamExt};
-use log::warn;
+use log::debug;
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
@@ -136,7 +136,7 @@ impl ExecutionPlan for WindowAggExec {
fn required_input_distribution(&self) -> Vec<Distribution> {
if self.partition_keys.is_empty() {
- warn!("No partition defined for WindowAggExec!!!");
+ debug!("No partition defined for WindowAggExec!!!");
vec![Distribution::SinglePartition]
} else {
//TODO support PartitionCollections if there is no common partition columns in the window_expr
diff --git a/datafusion/physical-expr/src/equivalence.rs b/datafusion/physical-expr/src/equivalence.rs
index 411a492a5..13cafaba4 100644
--- a/datafusion/physical-expr/src/equivalence.rs
+++ b/datafusion/physical-expr/src/equivalence.rs
@@ -23,27 +23,48 @@ use std::collections::HashMap;
use std::collections::HashSet;
/// EquivalenceProperties is a vec of EquivalentClass.
-#[derive(Debug, Default, Clone)]
+#[derive(Debug, Clone)]
pub struct EquivalenceProperties {
classes: Vec<EquivalentClass>,
+ schema: SchemaRef,
}
impl EquivalenceProperties {
- pub fn new() -> Self {
- EquivalenceProperties { classes: vec![] }
+ pub fn new(schema: SchemaRef) -> Self {
+ EquivalenceProperties {
+ classes: vec![],
+ schema,
+ }
}
pub fn classes(&self) -> &[EquivalentClass] {
&self.classes
}
+ pub fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
pub fn extend<I: IntoIterator<Item = EquivalentClass>>(&mut self, iter: I) {
- self.classes.extend(iter)
+ for ec in iter {
+ for column in ec.iter() {
+ assert_eq!(column.name(), self.schema.fields()[column.index()].name());
+ }
+ self.classes.push(ec)
+ }
}
/// Add new equal conditions into the EquivalenceProperties; the new equal conditions usually come from
/// equality predicates in a Join or Filter
pub fn add_equal_conditions(&mut self, new_conditions: (&Column, &Column)) {
+ assert_eq!(
+ new_conditions.0.name(),
+ self.schema.fields()[new_conditions.0.index()].name()
+ );
+ assert_eq!(
+ new_conditions.1.name(),
+ self.schema.fields()[new_conditions.1.index()].name()
+ );
let mut idx1: Option<usize> = None;
let mut idx2: Option<usize> = None;
for (idx, class) in self.classes.iter_mut().enumerate() {
@@ -89,47 +110,6 @@ impl EquivalenceProperties {
_ => {}
}
}
-
- pub fn merge_properties_with_alias(
- &mut self,
- alias_map: &HashMap<Column, Vec<Column>>,
- ) {
- for (column, columns) in alias_map {
- let mut find_match = false;
- for class in self.classes.iter_mut() {
- if class.contains(column) {
- for col in columns {
- class.insert(col.clone());
- }
- find_match = true;
- break;
- }
- }
- if !find_match {
- self.classes
- .push(EquivalentClass::new(column.clone(), columns.clone()));
- }
- }
- }
-
- pub fn truncate_properties_not_in_schema(&mut self, schema: &SchemaRef) {
- for class in self.classes.iter_mut() {
- let mut columns_to_remove = vec![];
- for column in class.iter() {
- if let Ok(idx) = schema.index_of(column.name()) {
- if idx != column.index() {
- columns_to_remove.push(column.clone());
- }
- } else {
- columns_to_remove.push(column.clone());
- }
- }
- for column in columns_to_remove {
- class.remove(&column);
- }
- }
- self.classes.retain(|props| props.len() > 1);
- }
}
/// Equivalent Class is a set of Columns that are known to have the same value in all tuples in a relation
@@ -195,15 +175,70 @@ impl EquivalentClass {
}
}
+/// Project the input EquivalenceProperties through a projection:
+/// 1) Aliases can introduce additional equivalence properties;
+/// for example, Projection(a, a as a1, a as a2) makes a, a1 and a2 equivalent.
+/// 2) Truncate the columns in the EquivalentClasses that are not in the output schema.
+pub fn project_equivalence_properties(
+ input_eq: EquivalenceProperties,
+ alias_map: &HashMap<Column, Vec<Column>>,
+ output_eq: &mut EquivalenceProperties,
+) {
+ let mut ec_classes = input_eq.classes().to_vec();
+ for (column, columns) in alias_map {
+ let mut find_match = false;
+ for class in ec_classes.iter_mut() {
+ if class.contains(column) {
+ for col in columns {
+ class.insert(col.clone());
+ }
+ find_match = true;
+ break;
+ }
+ }
+ if !find_match {
+ ec_classes.push(EquivalentClass::new(column.clone(), columns.clone()));
+ }
+ }
+
+ let schema = output_eq.schema();
+ for class in ec_classes.iter_mut() {
+ let mut columns_to_remove = vec![];
+ for column in class.iter() {
+ if column.index() >= schema.fields().len()
+ || schema.fields()[column.index()].name() != column.name()
+ {
+ columns_to_remove.push(column.clone());
+ }
+ }
+ for column in columns_to_remove {
+ class.remove(&column);
+ }
+ }
+ ec_classes.retain(|props| props.len() > 1);
+ output_eq.extend(ec_classes);
+}
+
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::Column;
+ use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
+ use std::sync::Arc;
+
#[test]
fn add_equal_conditions_test() -> Result<()> {
- let mut eq_properties = EquivalenceProperties::new();
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ Field::new("x", DataType::Int64, true),
+ Field::new("y", DataType::Int64, true),
+ ]));
+
+ let mut eq_properties = EquivalenceProperties::new(schema);
let new_condition = (&Column::new("a", 0), &Column::new("b", 1));
eq_properties.add_equal_conditions(new_condition);
assert_eq!(eq_properties.classes().len(), 1);
@@ -218,11 +253,11 @@ mod tests {
assert_eq!(eq_properties.classes().len(), 1);
assert_eq!(eq_properties.classes()[0].len(), 3);
- let new_condition = (&Column::new("x", 99), &Column::new("y", 100));
+ let new_condition = (&Column::new("x", 3), &Column::new("y", 4));
eq_properties.add_equal_conditions(new_condition);
assert_eq!(eq_properties.classes().len(), 2);
- let new_condition = (&Column::new("x", 99), &Column::new("a", 0));
+ let new_condition = (&Column::new("x", 3), &Column::new("a", 0));
eq_properties.add_equal_conditions(new_condition);
assert_eq!(eq_properties.classes().len(), 1);
assert_eq!(eq_properties.classes()[0].len(), 5);
@@ -231,26 +266,42 @@ mod tests {
}
#[test]
- fn merge_equivalence_properties_with_alias_test() -> Result<()> {
- let mut eq_properties = EquivalenceProperties::new();
- let mut alias_map = HashMap::new();
- alias_map.insert(
- Column::new("a", 0),
- vec![Column::new("a1", 1), Column::new("a2", 2)],
- );
+ fn project_equivalence_properties_test() -> Result<()> {
+ let input_schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int64, true),
+ Field::new("b", DataType::Int64, true),
+ Field::new("c", DataType::Int64, true),
+ ]));
- eq_properties.merge_properties_with_alias(&alias_map);
- assert_eq!(eq_properties.classes().len(), 1);
- assert_eq!(eq_properties.classes()[0].len(), 3);
+ let mut input_properties = EquivalenceProperties::new(input_schema);
+ let new_condition = (&Column::new("a", 0), &Column::new("b", 1));
+ input_properties.add_equal_conditions(new_condition);
+ let new_condition = (&Column::new("b", 1), &Column::new("c", 2));
+ input_properties.add_equal_conditions(new_condition);
+
+ let out_schema = Arc::new(Schema::new(vec![
+ Field::new("a1", DataType::Int64, true),
+ Field::new("a2", DataType::Int64, true),
+ Field::new("a3", DataType::Int64, true),
+ Field::new("a4", DataType::Int64, true),
+ ]));
let mut alias_map = HashMap::new();
alias_map.insert(
Column::new("a", 0),
- vec![Column::new("a3", 1), Column::new("a4", 2)],
+ vec![
+ Column::new("a1", 0),
+ Column::new("a2", 1),
+ Column::new("a3", 2),
+ Column::new("a4", 3),
+ ],
);
- eq_properties.merge_properties_with_alias(&alias_map);
- assert_eq!(eq_properties.classes().len(), 1);
- assert_eq!(eq_properties.classes()[0].len(), 5);
+ let mut out_properties = EquivalenceProperties::new(out_schema);
+
+ project_equivalence_properties(input_properties, &alias_map, &mut out_properties);
+ assert_eq!(out_properties.classes().len(), 1);
+ assert_eq!(out_properties.classes()[0].len(), 4);
+
Ok(())
}
}