You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/26 21:08:36 UTC

[arrow-datafusion] 01/01: Use projection

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch add_join_projection
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit e8edecfd4efdd5d0a1af6439fc59736313a5ce3e
Author: Daniël Heres <da...@coralogix.com>
AuthorDate: Mon Jun 26 23:08:23 2023 +0200

    Use projection
---
 .../src/physical_optimizer/dist_enforcement.rs     |   5 +
 .../core/src/physical_optimizer/join_selection.rs  |  11 +++
 .../core/src/physical_optimizer/pipeline_fixer.rs  |   1 +
 .../core/src/physical_plan/joins/hash_join.rs      |  11 ++-
 .../src/physical_plan/joins/nested_loop_join.rs    |   2 +-
 .../src/physical_plan/joins/sort_merge_join.rs     |   2 +-
 .../src/physical_plan/joins/symmetric_hash_join.rs |   5 +-
 datafusion/core/src/physical_plan/joins/utils.rs   |   3 +-
 datafusion/core/src/physical_planner.rs            |   5 +
 datafusion/expr/src/logical_plan/builder.rs        | 102 ++++++++++++---------
 datafusion/expr/src/logical_plan/plan.rs           |  11 ++-
 datafusion/expr/src/utils.rs                       |  10 +-
 datafusion/optimizer/src/eliminate_cross_join.rs   |   3 +
 datafusion/optimizer/src/eliminate_outer_join.rs   |   1 +
 .../optimizer/src/extract_equijoin_predicate.rs    |   2 +
 datafusion/optimizer/src/push_down_limit.rs        |   1 +
 datafusion/optimizer/src/push_down_projection.rs   |  13 ++-
 17 files changed, 134 insertions(+), 54 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index cb98e69d7a..1e539d18f6 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -154,6 +154,7 @@ fn adjust_input_keys_ordering(
         join_type,
         mode,
         null_equals_null,
+        projection,
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
@@ -169,6 +170,7 @@ fn adjust_input_keys_ordering(
                             join_type,
                             PartitionMode::Partitioned,
                             *null_equals_null,
+                            projection.clone(),
                         )?) as Arc<dyn ExecutionPlan>)
                     };
                 Some(reorder_partitioned_join_keys(
@@ -541,6 +543,7 @@ fn reorder_join_keys_to_inputs(
         join_type,
         mode,
         null_equals_null,
+        projection,
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
@@ -570,6 +573,7 @@ fn reorder_join_keys_to_inputs(
                             join_type,
                             PartitionMode::Partitioned,
                             *null_equals_null,
+                            projection.clone(),
                         )?))
                     } else {
                         Ok(plan)
@@ -1123,6 +1127,7 @@ mod tests {
                 join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )
             .unwrap(),
         )
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs
index a9dec73c36..ac24c6275f 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -137,6 +137,7 @@ pub fn swap_hash_join(
         &swap_join_type(*hash_join.join_type()),
         partition_mode,
         hash_join.null_equals_null(),
+        None,
     )?;
     if matches!(
         hash_join.join_type(),
@@ -333,6 +334,7 @@ fn try_collect_left(
                     hash_join.join_type(),
                     PartitionMode::CollectLeft,
                     hash_join.null_equals_null(),
+                    hash_join.projection.clone(),
                 )?)))
             }
         }
@@ -344,6 +346,7 @@ fn try_collect_left(
             hash_join.join_type(),
             PartitionMode::CollectLeft,
             hash_join.null_equals_null(),
+            hash_join.projection.clone(),
         )?))),
         (false, true) => {
             if supports_swap(*hash_join.join_type()) {
@@ -371,6 +374,7 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPl
             hash_join.join_type(),
             PartitionMode::Partitioned,
             hash_join.null_equals_null(),
+            hash_join.projection.clone(),
         )?))
     }
 }
@@ -495,6 +499,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -543,6 +548,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -594,6 +600,7 @@ mod tests {
                 &join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )
             .unwrap();
 
@@ -659,6 +666,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
         let child_schema = child_join.schema();
@@ -675,6 +683,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -712,6 +721,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -937,6 +947,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::Auto,
             false,
+            None,
         )
         .unwrap();
 
diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index caae774345..cfe34558c4 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -644,6 +644,7 @@ mod hash_join_tests {
             &t.initial_join_type,
             t.initial_mode,
             false,
+            None,
         )?;
 
         let initial_hash_join_state = PipelineStatePropagator {
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index a3c553c9b3..b38811eefd 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -122,6 +122,8 @@ pub struct HashJoinExec {
     column_indices: Vec<ColumnIndex>,
     /// If null_equals_null is true, null == null else null != null
     pub(crate) null_equals_null: bool,
+    /// Optional output projection
+    pub projection: Option<Vec<Column>>,
 }
 
 impl HashJoinExec {
@@ -136,6 +138,7 @@ impl HashJoinExec {
         join_type: &JoinType,
         partition_mode: PartitionMode,
         null_equals_null: bool,
+        projection: Option<Vec<Column>>,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -148,7 +151,7 @@ impl HashJoinExec {
         check_join_is_valid(&left_schema, &right_schema, &on)?;
 
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, projection);
 
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
 
@@ -165,6 +168,7 @@ impl HashJoinExec {
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
             null_equals_null,
+            projection,
         })
     }
 
@@ -337,6 +341,7 @@ impl ExecutionPlan for HashJoinExec {
             &self.join_type,
             self.mode,
             self.null_equals_null,
+            self.projection,
         )?))
     }
 
@@ -1358,6 +1363,7 @@ mod tests {
             join_type,
             PartitionMode::CollectLeft,
             null_equals_null,
+            None,
         )
     }
 
@@ -1377,6 +1383,7 @@ mod tests {
             join_type,
             PartitionMode::CollectLeft,
             null_equals_null,
+            None,
         )
     }
 
@@ -1431,6 +1438,7 @@ mod tests {
             join_type,
             PartitionMode::Partitioned,
             null_equals_null,
+            None,
         )?;
 
         let columns = columns(&join.schema());
@@ -3164,6 +3172,7 @@ mod tests {
                 &join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )?;
 
             let stream = join.execute(1, task_ctx)?;
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index 6586456fd2..5a2b0e8337 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -106,7 +106,7 @@ impl NestedLoopJoinExec {
         let right_schema = right.schema();
         check_join_is_valid(&left_schema, &right_schema, &[])?;
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, None);
         Ok(NestedLoopJoinExec {
             left,
             right,
diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
index bc8c686670..324f8582e9 100644
--- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
@@ -177,7 +177,7 @@ impl SortMergeJoinExec {
         };
 
         let schema =
-            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
+            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type, None).0);
 
         Ok(Self {
             left,
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index b46aba2fb5..7df848e3eb 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -295,7 +295,7 @@ impl SymmetricHashJoinExec {
 
         // Build the join schema from the left and right schemas:
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, None);
 
         // Initialize the random state for the join operation:
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -1862,6 +1862,7 @@ mod tests {
             join_type,
             PartitionMode::Partitioned,
             null_equals_null,
+            None,
         )?;
 
         let mut batches = vec![];
@@ -3026,7 +3027,7 @@ mod tests {
 
         // Build the join schema from the left and right schemas
         let (schema, join_column_indices) =
-            build_join_schema(&left_schema, &right_schema, &join_type);
+            build_join_schema(&left_schema, &right_schema, &join_type, None);
         let join_schema = Arc::new(schema);
 
         // Sort information for MemoryExec
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 627bdeebc5..772201315d 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -350,6 +350,7 @@ pub fn build_join_schema(
     left: &Schema,
     right: &Schema,
     join_type: &JoinType,
+    projection: Option<Vec<Column>>,
 ) -> (Schema, Vec<ColumnIndex>) {
     let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
         JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
@@ -1197,7 +1198,7 @@ mod tests {
         ];
 
         for (left_in, right_in, join_type, left_out, right_out) in cases {
-            let (schema, _) = build_join_schema(left_in, right_in, &join_type);
+            let (schema, _) = build_join_schema(left_in, right_in, &join_type, None);
 
             let expected_fields = left_out
                 .fields()
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 75566208e3..525a6e34fb 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -901,6 +901,7 @@ impl DefaultPhysicalPlanner {
                     join_type,
                     null_equals_null,
                     schema: join_schema,
+                    projection,
                     ..
                 }) => {
                     let null_equals_null = *null_equals_null;
@@ -990,6 +991,8 @@ impl DefaultPhysicalPlanner {
                         })
                         .collect::<Result<join_utils::JoinOn>>()?;
 
+                    let projection: Option<Vec<Column>> = projection.map(|proj|proj.iter().enumerate().map(|col|Column::new(col.name, 0)).collect());
+
                     let join_filter = match filter {
                         Some(expr) => {
                             // Extract columns from filter expression and saved in a HashSet
@@ -1095,6 +1098,7 @@ impl DefaultPhysicalPlanner {
                             join_type,
                             partition_mode,
                             null_equals_null,
+                            projection.clone(),
                         )?))
                     } else {
                         Ok(Arc::new(HashJoinExec::try_new(
@@ -1105,6 +1109,7 @@ impl DefaultPhysicalPlanner {
                             join_type,
                             PartitionMode::CollectLeft,
                             null_equals_null,
+                            projection.clone(),
                         )?))
                     }
                 }
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 3d34c087ac..34ae0eb61a 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -721,7 +721,7 @@ impl LogicalPlanBuilder {
             .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
             .collect();
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
 
         Ok(Self::from(LogicalPlan::Join(Join {
             left: Arc::new(self.plan),
@@ -732,6 +732,7 @@ impl LogicalPlanBuilder {
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
             null_equals_null,
+            projection: None,
         })))
     }
 
@@ -754,7 +755,7 @@ impl LogicalPlanBuilder {
 
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
         let mut join_on: Vec<(Expr, Expr)> = vec![];
         let mut filters: Option<Expr> = None;
         for (l, r) in &on {
@@ -796,6 +797,7 @@ impl LogicalPlanBuilder {
                 join_constraint: JoinConstraint::Using,
                 schema: DFSchemaRef::new(join_schema),
                 null_equals_null: false,
+                projection: None,
             })))
         }
     }
@@ -1012,7 +1014,7 @@ impl LogicalPlanBuilder {
             .collect::<Result<Vec<_>>>()?;
 
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
 
         Ok(Self::from(LogicalPlan::Join(Join {
             left: Arc::new(self.plan),
@@ -1023,6 +1025,7 @@ impl LogicalPlanBuilder {
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
             null_equals_null: false,
+            projection: None,
         })))
     }
 
@@ -1038,6 +1041,7 @@ pub fn build_join_schema(
     left: &DFSchema,
     right: &DFSchema,
     join_type: &JoinType,
+    projection: Option<&Vec<Column>>,
 ) -> Result<DFSchema> {
     fn nullify_fields(fields: &[DFField]) -> Vec<DFField> {
         fields
@@ -1049,51 +1053,65 @@ pub fn build_join_schema(
     let right_fields = right.fields();
     let left_fields = left.fields();
 
-    let fields: Vec<DFField> = match join_type {
-        JoinType::Inner => {
-            // left then right
-            left_fields
+    let fields = {
+        if let Some(projection) = projection {
+            projection
                 .iter()
-                .chain(right_fields.iter())
-                .cloned()
-                .collect()
-        }
-        JoinType::Left => {
-            // left then right, right set to nullable in case of not matched scenario
-            left_fields
-                .iter()
-                .chain(&nullify_fields(right_fields))
-                .cloned()
-                .collect()
-        }
-        JoinType::Right => {
-            // left then right, left set to nullable in case of not matched scenario
-            nullify_fields(left_fields)
-                .iter()
-                .chain(right_fields.iter())
-                .cloned()
-                .collect()
-        }
-        JoinType::Full => {
-            // left then right, all set to nullable in case of not matched scenario
-            nullify_fields(left_fields)
-                .iter()
-                .chain(&nullify_fields(right_fields))
-                .cloned()
-                .collect()
-        }
-        JoinType::LeftSemi | JoinType::LeftAnti => {
-            // Only use the left side for the schema
-            left_fields.clone()
-        }
-        JoinType::RightSemi | JoinType::RightAnti => {
-            // Only use the right side for the schema
-            right_fields.clone()
+                .map(|col| {
+                    left.field_from_column(col)
+                        .or_else(|_| right.field_from_column(col))
+                        .cloned()
+                })
+                .collect::<Result<Vec<DFField>>>()?
+        } else {
+            match join_type {
+                JoinType::Inner => {
+                    // left then right
+                    left_fields
+                        .iter()
+                        .chain(right_fields.iter())
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Left => {
+                    // left then right, right set to nullable in case of not matched scenario
+                    left_fields
+                        .iter()
+                        .chain(&nullify_fields(right_fields))
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Right => {
+                    // left then right, left set to nullable in case of not matched scenario
+                    nullify_fields(left_fields)
+                        .iter()
+                        .chain(right_fields.iter())
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Full => {
+                    // left then right, all set to nullable in case of not matched scenario
+                    nullify_fields(left_fields)
+                        .iter()
+                        .chain(&nullify_fields(right_fields))
+                        .cloned()
+                        .collect()
+                }
+                JoinType::LeftSemi | JoinType::LeftAnti => {
+                    // Only use the left side for the schema
+                    left_fields.clone()
+                }
+                JoinType::RightSemi | JoinType::RightAnti => {
+                    // Only use the right side for the schema
+                    right_fields.clone()
+                }
+            }
         }
     };
 
     let mut metadata = left.metadata().clone();
     metadata.extend(right.metadata().clone());
+
     DFSchema::new_with_metadata(fields, metadata)
 }
 
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ab45047acf..d4517917ec 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1660,6 +1660,8 @@ pub struct Join {
     pub schema: DFSchemaRef,
     /// If null_equals_null is true, null == null else null != null
     pub null_equals_null: bool,
+    /// optional projection
+    pub projection: Option<Vec<Column>>,
 }
 
 impl Join {
@@ -1681,8 +1683,12 @@ impl Join {
             .zip(column_on.1.into_iter())
             .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
             .collect();
-        let join_schema =
-            build_join_schema(left.schema(), right.schema(), &original_join.join_type)?;
+        let join_schema = build_join_schema(
+            left.schema(),
+            right.schema(),
+            &original_join.join_type,
+            original_join.projection.as_ref(),
+        )?;
 
         Ok(Join {
             left,
@@ -1693,6 +1699,7 @@ impl Join {
             join_constraint: original_join.join_constraint,
             schema: Arc::new(join_schema),
             null_equals_null: original_join.null_equals_null,
+            projection: None,
         })
     }
 }
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2b6fc5793a..2ae116e900 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -845,10 +845,15 @@ pub fn from_plan(
             join_constraint,
             on,
             null_equals_null,
+            projection,
             ..
         }) => {
-            let schema =
-                build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
+            let schema = build_join_schema(
+                inputs[0].schema(),
+                inputs[1].schema(),
+                join_type,
+                projection.as_ref(),
+            )?;
 
             let equi_expr_count = on.len();
             assert!(expr.len() >= equi_expr_count);
@@ -881,6 +886,7 @@ pub fn from_plan(
                 filter: filter_expr,
                 schema: DFSchemaRef::new(schema),
                 null_equals_null: *null_equals_null,
+                projection: projection.clone(),
             }))
         }
         LogicalPlan::CrossJoin(_) => {
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 533566a0bf..e32311d67d 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -214,6 +214,7 @@ fn find_inner_join(
                 left_input.schema(),
                 right_input.schema(),
                 &JoinType::Inner,
+                None,
             )?);
 
             return Ok(LogicalPlan::Join(Join {
@@ -225,6 +226,7 @@ fn find_inner_join(
                 filter: None,
                 schema: join_schema,
                 null_equals_null: false,
+                projection: None,
             }));
         }
     }
@@ -233,6 +235,7 @@ fn find_inner_join(
         left_input.schema(),
         right.schema(),
         &JoinType::Inner,
+        None,
     )?);
 
     Ok(LogicalPlan::CrossJoin(CrossJoin {
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index e4d57f0209..e2df09da56 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -105,6 +105,7 @@ impl OptimizerRule for EliminateOuterJoin {
                         filter: join.filter.clone(),
                         schema: join.schema.clone(),
                         null_equals_null: join.null_equals_null,
+                        projection: join.projection.clone(),
                     });
                     let new_plan = plan.with_new_inputs(&[new_join])?;
                     Ok(Some(new_plan))
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 20b9c62971..41db207d53 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -55,6 +55,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                 join_constraint,
                 schema,
                 null_equals_null,
+                projection,
             }) => {
                 let left_schema = left.schema();
                 let right_schema = right.schema();
@@ -80,6 +81,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                             join_constraint: *join_constraint,
                             schema: schema.clone(),
                             null_equals_null: *null_equals_null,
+                            projection: projection.clone(),
                         })
                     });
 
diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 6703a1d787..42915b2e3c 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -266,6 +266,7 @@ fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
                 join_constraint: join.join_constraint,
                 schema: join.schema.clone(),
                 null_equals_null: join.null_equals_null,
+                projection: join.projection.clone(),
             })
         }
     }
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 4773a944f4..65409b0bd0 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -101,6 +101,10 @@ impl OptimizerRule for PushDownProjection {
                 for e in projection.expr.iter() {
                     expr_to_columns(e, &mut push_columns)?;
                 }
+
+                // Keep columns to use for join output projection
+                let output_columns = push_columns.clone();
+
                 for (l, r) in join.on.iter() {
                     expr_to_columns(l, &mut push_columns)?;
                     expr_to_columns(r, &mut push_columns)?;
@@ -119,9 +123,14 @@ impl OptimizerRule for PushDownProjection {
                     join.right.schema(),
                     join.right.clone(),
                 )?;
-                let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;
 
-                generate_plan!(projection_is_empty, plan, new_join)
+                let mut join = join.clone();
+
+                join.left = Arc::new(new_left);
+                join.right = Arc::new(new_right);
+                join.projection = Some(output_columns.into_iter().collect());
+
+                generate_plan!(projection_is_empty, plan, LogicalPlan::Join(join))
             }
             LogicalPlan::CrossJoin(join) => {
                 // collect column in on/filter in join and projection.