You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/26 21:08:35 UTC

[arrow-datafusion] branch add_join_projection created (now e8edecfd4e)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a change to branch add_join_projection
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


      at e8edecfd4e Use projection

This branch includes the following new commits:

     new e8edecfd4e Use projection

The 1 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[arrow-datafusion] 01/01: Use projection

Posted by dh...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch add_join_projection
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit e8edecfd4efdd5d0a1af6439fc59736313a5ce3e
Author: Daniël Heres <da...@coralogix.com>
AuthorDate: Mon Jun 26 23:08:23 2023 +0200

    Use projection
---
 .../src/physical_optimizer/dist_enforcement.rs     |   5 +
 .../core/src/physical_optimizer/join_selection.rs  |  11 +++
 .../core/src/physical_optimizer/pipeline_fixer.rs  |   1 +
 .../core/src/physical_plan/joins/hash_join.rs      |  11 ++-
 .../src/physical_plan/joins/nested_loop_join.rs    |   2 +-
 .../src/physical_plan/joins/sort_merge_join.rs     |   2 +-
 .../src/physical_plan/joins/symmetric_hash_join.rs |   5 +-
 datafusion/core/src/physical_plan/joins/utils.rs   |   3 +-
 datafusion/core/src/physical_planner.rs            |   5 +
 datafusion/expr/src/logical_plan/builder.rs        | 102 ++++++++++++---------
 datafusion/expr/src/logical_plan/plan.rs           |  11 ++-
 datafusion/expr/src/utils.rs                       |  10 +-
 datafusion/optimizer/src/eliminate_cross_join.rs   |   3 +
 datafusion/optimizer/src/eliminate_outer_join.rs   |   1 +
 .../optimizer/src/extract_equijoin_predicate.rs    |   2 +
 datafusion/optimizer/src/push_down_limit.rs        |   1 +
 datafusion/optimizer/src/push_down_projection.rs   |  13 ++-
 17 files changed, 134 insertions(+), 54 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/dist_enforcement.rs b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
index cb98e69d7a..1e539d18f6 100644
--- a/datafusion/core/src/physical_optimizer/dist_enforcement.rs
+++ b/datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -154,6 +154,7 @@ fn adjust_input_keys_ordering(
         join_type,
         mode,
         null_equals_null,
+        projection,
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
@@ -169,6 +170,7 @@ fn adjust_input_keys_ordering(
                             join_type,
                             PartitionMode::Partitioned,
                             *null_equals_null,
+                            projection.clone(),
                         )?) as Arc<dyn ExecutionPlan>)
                     };
                 Some(reorder_partitioned_join_keys(
@@ -541,6 +543,7 @@ fn reorder_join_keys_to_inputs(
         join_type,
         mode,
         null_equals_null,
+        projection,
         ..
     }) = plan_any.downcast_ref::<HashJoinExec>()
     {
@@ -570,6 +573,7 @@ fn reorder_join_keys_to_inputs(
                             join_type,
                             PartitionMode::Partitioned,
                             *null_equals_null,
+                            projection.clone(),
                         )?))
                     } else {
                         Ok(plan)
@@ -1123,6 +1127,7 @@ mod tests {
                 join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )
             .unwrap(),
         )
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs
index a9dec73c36..ac24c6275f 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -137,6 +137,7 @@ pub fn swap_hash_join(
         &swap_join_type(*hash_join.join_type()),
         partition_mode,
         hash_join.null_equals_null(),
+        None,
     )?;
     if matches!(
         hash_join.join_type(),
@@ -333,6 +334,7 @@ fn try_collect_left(
                     hash_join.join_type(),
                     PartitionMode::CollectLeft,
                     hash_join.null_equals_null(),
+                    hash_join.projection.clone(),
                 )?)))
             }
         }
@@ -344,6 +346,7 @@ fn try_collect_left(
             hash_join.join_type(),
             PartitionMode::CollectLeft,
             hash_join.null_equals_null(),
+            hash_join.projection.clone(),
         )?))),
         (false, true) => {
             if supports_swap(*hash_join.join_type()) {
@@ -371,6 +374,7 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result<Arc<dyn ExecutionPl
             hash_join.join_type(),
             PartitionMode::Partitioned,
             hash_join.null_equals_null(),
+            hash_join.projection.clone(),
         )?))
     }
 }
@@ -495,6 +499,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -543,6 +548,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -594,6 +600,7 @@ mod tests {
                 &join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )
             .unwrap();
 
@@ -659,6 +666,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
         let child_schema = child_join.schema();
@@ -675,6 +683,7 @@ mod tests {
             &JoinType::Left,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -712,6 +721,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::CollectLeft,
             false,
+            None,
         )
         .unwrap();
 
@@ -937,6 +947,7 @@ mod tests {
             &JoinType::Inner,
             PartitionMode::Auto,
             false,
+            None,
         )
         .unwrap();
 
diff --git a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
index caae774345..cfe34558c4 100644
--- a/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
+++ b/datafusion/core/src/physical_optimizer/pipeline_fixer.rs
@@ -644,6 +644,7 @@ mod hash_join_tests {
             &t.initial_join_type,
             t.initial_mode,
             false,
+            None,
         )?;
 
         let initial_hash_join_state = PipelineStatePropagator {
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index a3c553c9b3..b38811eefd 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -122,6 +122,8 @@ pub struct HashJoinExec {
     column_indices: Vec<ColumnIndex>,
     /// If null_equals_null is true, null == null else null != null
     pub(crate) null_equals_null: bool,
+    /// Optional output projection
+    pub projection: Option<Vec<Column>>,
 }
 
 impl HashJoinExec {
@@ -136,6 +138,7 @@ impl HashJoinExec {
         join_type: &JoinType,
         partition_mode: PartitionMode,
         null_equals_null: bool,
+        projection: Option<Vec<Column>>,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -148,7 +151,7 @@ impl HashJoinExec {
         check_join_is_valid(&left_schema, &right_schema, &on)?;
 
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, projection);
 
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
 
@@ -165,6 +168,7 @@ impl HashJoinExec {
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
             null_equals_null,
+            projection,
         })
     }
 
@@ -337,6 +341,7 @@ impl ExecutionPlan for HashJoinExec {
             &self.join_type,
             self.mode,
             self.null_equals_null,
+            self.projection,
         )?))
     }
 
@@ -1358,6 +1363,7 @@ mod tests {
             join_type,
             PartitionMode::CollectLeft,
             null_equals_null,
+            None,
         )
     }
 
@@ -1377,6 +1383,7 @@ mod tests {
             join_type,
             PartitionMode::CollectLeft,
             null_equals_null,
+            None,
         )
     }
 
@@ -1431,6 +1438,7 @@ mod tests {
             join_type,
             PartitionMode::Partitioned,
             null_equals_null,
+            None,
         )?;
 
         let columns = columns(&join.schema());
@@ -3164,6 +3172,7 @@ mod tests {
                 &join_type,
                 PartitionMode::Partitioned,
                 false,
+                None,
             )?;
 
             let stream = join.execute(1, task_ctx)?;
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index 6586456fd2..5a2b0e8337 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -106,7 +106,7 @@ impl NestedLoopJoinExec {
         let right_schema = right.schema();
         check_join_is_valid(&left_schema, &right_schema, &[])?;
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, None);
         Ok(NestedLoopJoinExec {
             left,
             right,
diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
index bc8c686670..324f8582e9 100644
--- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
@@ -177,7 +177,7 @@ impl SortMergeJoinExec {
         };
 
         let schema =
-            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
+            Arc::new(build_join_schema(&left_schema, &right_schema, &join_type, None).0);
 
         Ok(Self {
             left,
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index b46aba2fb5..7df848e3eb 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -295,7 +295,7 @@ impl SymmetricHashJoinExec {
 
         // Build the join schema from the left and right schemas:
         let (schema, column_indices) =
-            build_join_schema(&left_schema, &right_schema, join_type);
+            build_join_schema(&left_schema, &right_schema, join_type, None);
 
         // Initialize the random state for the join operation:
         let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -1862,6 +1862,7 @@ mod tests {
             join_type,
             PartitionMode::Partitioned,
             null_equals_null,
+            None,
         )?;
 
         let mut batches = vec![];
@@ -3026,7 +3027,7 @@ mod tests {
 
         // Build the join schema from the left and right schemas
         let (schema, join_column_indices) =
-            build_join_schema(&left_schema, &right_schema, &join_type);
+            build_join_schema(&left_schema, &right_schema, &join_type, None);
         let join_schema = Arc::new(schema);
 
         // Sort information for MemoryExec
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 627bdeebc5..772201315d 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -350,6 +350,7 @@ pub fn build_join_schema(
     left: &Schema,
     right: &Schema,
     join_type: &JoinType,
+    projection: Option<Vec<Column>>,
 ) -> (Schema, Vec<ColumnIndex>) {
     let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = match join_type {
         JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
@@ -1197,7 +1198,7 @@ mod tests {
         ];
 
         for (left_in, right_in, join_type, left_out, right_out) in cases {
-            let (schema, _) = build_join_schema(left_in, right_in, &join_type);
+            let (schema, _) = build_join_schema(left_in, right_in, &join_type, None);
 
             let expected_fields = left_out
                 .fields()
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 75566208e3..525a6e34fb 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -901,6 +901,7 @@ impl DefaultPhysicalPlanner {
                     join_type,
                     null_equals_null,
                     schema: join_schema,
+                    projection,
                     ..
                 }) => {
                     let null_equals_null = *null_equals_null;
@@ -990,6 +991,8 @@ impl DefaultPhysicalPlanner {
                         })
                         .collect::<Result<join_utils::JoinOn>>()?;
 
+                    let projection: Option<Vec<Column>> = projection.map(|proj|proj.iter().enumerate().map(|col|Column::new(col.name, 0)).collect());
+
                     let join_filter = match filter {
                         Some(expr) => {
                             // Extract columns from filter expression and saved in a HashSet
@@ -1095,6 +1098,7 @@ impl DefaultPhysicalPlanner {
                             join_type,
                             partition_mode,
                             null_equals_null,
+                            projection.clone(),
                         )?))
                     } else {
                         Ok(Arc::new(HashJoinExec::try_new(
@@ -1105,6 +1109,7 @@ impl DefaultPhysicalPlanner {
                             join_type,
                             PartitionMode::CollectLeft,
                             null_equals_null,
+                            projection.clone(),
                         )?))
                     }
                 }
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 3d34c087ac..34ae0eb61a 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -721,7 +721,7 @@ impl LogicalPlanBuilder {
             .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
             .collect();
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
 
         Ok(Self::from(LogicalPlan::Join(Join {
             left: Arc::new(self.plan),
@@ -732,6 +732,7 @@ impl LogicalPlanBuilder {
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
             null_equals_null,
+            projection: None,
         })))
     }
 
@@ -754,7 +755,7 @@ impl LogicalPlanBuilder {
 
         let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
         let mut join_on: Vec<(Expr, Expr)> = vec![];
         let mut filters: Option<Expr> = None;
         for (l, r) in &on {
@@ -796,6 +797,7 @@ impl LogicalPlanBuilder {
                 join_constraint: JoinConstraint::Using,
                 schema: DFSchemaRef::new(join_schema),
                 null_equals_null: false,
+                projection: None,
             })))
         }
     }
@@ -1012,7 +1014,7 @@ impl LogicalPlanBuilder {
             .collect::<Result<Vec<_>>>()?;
 
         let join_schema =
-            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+            build_join_schema(self.plan.schema(), right.schema(), &join_type, None)?;
 
         Ok(Self::from(LogicalPlan::Join(Join {
             left: Arc::new(self.plan),
@@ -1023,6 +1025,7 @@ impl LogicalPlanBuilder {
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
             null_equals_null: false,
+            projection: None,
         })))
     }
 
@@ -1038,6 +1041,7 @@ pub fn build_join_schema(
     left: &DFSchema,
     right: &DFSchema,
     join_type: &JoinType,
+    projection: Option<&Vec<Column>>,
 ) -> Result<DFSchema> {
     fn nullify_fields(fields: &[DFField]) -> Vec<DFField> {
         fields
@@ -1049,51 +1053,65 @@ pub fn build_join_schema(
     let right_fields = right.fields();
     let left_fields = left.fields();
 
-    let fields: Vec<DFField> = match join_type {
-        JoinType::Inner => {
-            // left then right
-            left_fields
+    let fields = {
+        if let Some(projection) = projection {
+            projection
                 .iter()
-                .chain(right_fields.iter())
-                .cloned()
-                .collect()
-        }
-        JoinType::Left => {
-            // left then right, right set to nullable in case of not matched scenario
-            left_fields
-                .iter()
-                .chain(&nullify_fields(right_fields))
-                .cloned()
-                .collect()
-        }
-        JoinType::Right => {
-            // left then right, left set to nullable in case of not matched scenario
-            nullify_fields(left_fields)
-                .iter()
-                .chain(right_fields.iter())
-                .cloned()
-                .collect()
-        }
-        JoinType::Full => {
-            // left then right, all set to nullable in case of not matched scenario
-            nullify_fields(left_fields)
-                .iter()
-                .chain(&nullify_fields(right_fields))
-                .cloned()
-                .collect()
-        }
-        JoinType::LeftSemi | JoinType::LeftAnti => {
-            // Only use the left side for the schema
-            left_fields.clone()
-        }
-        JoinType::RightSemi | JoinType::RightAnti => {
-            // Only use the right side for the schema
-            right_fields.clone()
+                .map(|col| {
+                    left.field_from_column(col)
+                        .or_else(|_| right.field_from_column(col))
+                        .cloned()
+                })
+                .collect::<Result<Vec<DFField>>>()?
+        } else {
+            match join_type {
+                JoinType::Inner => {
+                    // left then right
+                    left_fields
+                        .iter()
+                        .chain(right_fields.iter())
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Left => {
+                    // left then right, right set to nullable in case of not matched scenario
+                    left_fields
+                        .iter()
+                        .chain(&nullify_fields(right_fields))
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Right => {
+                    // left then right, left set to nullable in case of not matched scenario
+                    nullify_fields(left_fields)
+                        .iter()
+                        .chain(right_fields.iter())
+                        .cloned()
+                        .collect()
+                }
+                JoinType::Full => {
+                    // left then right, all set to nullable in case of not matched scenario
+                    nullify_fields(left_fields)
+                        .iter()
+                        .chain(&nullify_fields(right_fields))
+                        .cloned()
+                        .collect()
+                }
+                JoinType::LeftSemi | JoinType::LeftAnti => {
+                    // Only use the left side for the schema
+                    left_fields.clone()
+                }
+                JoinType::RightSemi | JoinType::RightAnti => {
+                    // Only use the right side for the schema
+                    right_fields.clone()
+                }
+            }
         }
     };
 
     let mut metadata = left.metadata().clone();
     metadata.extend(right.metadata().clone());
+
     DFSchema::new_with_metadata(fields, metadata)
 }
 
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ab45047acf..d4517917ec 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1660,6 +1660,8 @@ pub struct Join {
     pub schema: DFSchemaRef,
     /// If null_equals_null is true, null == null else null != null
     pub null_equals_null: bool,
+    /// optional projection
+    pub projection: Option<Vec<Column>>,
 }
 
 impl Join {
@@ -1681,8 +1683,12 @@ impl Join {
             .zip(column_on.1.into_iter())
             .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
             .collect();
-        let join_schema =
-            build_join_schema(left.schema(), right.schema(), &original_join.join_type)?;
+        let join_schema = build_join_schema(
+            left.schema(),
+            right.schema(),
+            &original_join.join_type,
+            original_join.projection.as_ref(),
+        )?;
 
         Ok(Join {
             left,
@@ -1693,6 +1699,7 @@ impl Join {
             join_constraint: original_join.join_constraint,
             schema: Arc::new(join_schema),
             null_equals_null: original_join.null_equals_null,
+            projection: None,
         })
     }
 }
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 2b6fc5793a..2ae116e900 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -845,10 +845,15 @@ pub fn from_plan(
             join_constraint,
             on,
             null_equals_null,
+            projection,
             ..
         }) => {
-            let schema =
-                build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
+            let schema = build_join_schema(
+                inputs[0].schema(),
+                inputs[1].schema(),
+                join_type,
+                projection.as_ref(),
+            )?;
 
             let equi_expr_count = on.len();
             assert!(expr.len() >= equi_expr_count);
@@ -881,6 +886,7 @@ pub fn from_plan(
                 filter: filter_expr,
                 schema: DFSchemaRef::new(schema),
                 null_equals_null: *null_equals_null,
+                projection: projection.clone(),
             }))
         }
         LogicalPlan::CrossJoin(_) => {
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index 533566a0bf..e32311d67d 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -214,6 +214,7 @@ fn find_inner_join(
                 left_input.schema(),
                 right_input.schema(),
                 &JoinType::Inner,
+                None,
             )?);
 
             return Ok(LogicalPlan::Join(Join {
@@ -225,6 +226,7 @@ fn find_inner_join(
                 filter: None,
                 schema: join_schema,
                 null_equals_null: false,
+                projection: None,
             }));
         }
     }
@@ -233,6 +235,7 @@ fn find_inner_join(
         left_input.schema(),
         right.schema(),
         &JoinType::Inner,
+        None,
     )?);
 
     Ok(LogicalPlan::CrossJoin(CrossJoin {
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs
index e4d57f0209..e2df09da56 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -105,6 +105,7 @@ impl OptimizerRule for EliminateOuterJoin {
                         filter: join.filter.clone(),
                         schema: join.schema.clone(),
                         null_equals_null: join.null_equals_null,
+                        projection: join.projection.clone(),
                     });
                     let new_plan = plan.with_new_inputs(&[new_join])?;
                     Ok(Some(new_plan))
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs
index 20b9c62971..41db207d53 100644
--- a/datafusion/optimizer/src/extract_equijoin_predicate.rs
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -55,6 +55,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                 join_constraint,
                 schema,
                 null_equals_null,
+                projection,
             }) => {
                 let left_schema = left.schema();
                 let right_schema = right.schema();
@@ -80,6 +81,7 @@ impl OptimizerRule for ExtractEquijoinPredicate {
                             join_constraint: *join_constraint,
                             schema: schema.clone(),
                             null_equals_null: *null_equals_null,
+                            projection: projection.clone(),
                         })
                     });
 
diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs
index 6703a1d787..42915b2e3c 100644
--- a/datafusion/optimizer/src/push_down_limit.rs
+++ b/datafusion/optimizer/src/push_down_limit.rs
@@ -266,6 +266,7 @@ fn push_down_join(join: &Join, limit: usize) -> Option<Join> {
                 join_constraint: join.join_constraint,
                 schema: join.schema.clone(),
                 null_equals_null: join.null_equals_null,
+                projection: join.projection.clone(),
             })
         }
     }
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 4773a944f4..65409b0bd0 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -101,6 +101,10 @@ impl OptimizerRule for PushDownProjection {
                 for e in projection.expr.iter() {
                     expr_to_columns(e, &mut push_columns)?;
                 }
+
+                // Keep columns to use for join output projection
+                let output_columns = push_columns.clone();
+
                 for (l, r) in join.on.iter() {
                     expr_to_columns(l, &mut push_columns)?;
                     expr_to_columns(r, &mut push_columns)?;
@@ -119,9 +123,14 @@ impl OptimizerRule for PushDownProjection {
                     join.right.schema(),
                     join.right.clone(),
                 )?;
-                let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;
 
-                generate_plan!(projection_is_empty, plan, new_join)
+                let mut join = join.clone();
+
+                join.left = Arc::new(new_left);
+                join.right = Arc::new(new_right);
+                join.projection = Some(output_columns.into_iter().collect());
+
+                generate_plan!(projection_is_empty, plan, LogicalPlan::Join(join))
             }
             LogicalPlan::CrossJoin(join) => {
                 // collect column in on/filter in join and projection.