You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/11/06 11:05:09 UTC

[arrow-datafusion] branch master updated: fix comments (#1135)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 2b24b89  fix comments (#1135)
2b24b89 is described below

commit 2b24b892058d2dc87a29e931c1e81d621a60942e
Author: Carlos <wx...@gmail.com>
AuthorDate: Sat Nov 6 19:05:00 2021 +0800

    fix comments (#1135)
---
 ballista/rust/core/proto/ballista.proto            |   2 +
 .../rust/core/src/serde/logical_plan/to_proto.rs   |   2 +
 .../core/src/serde/physical_plan/from_proto.rs     |   1 +
 ballista/rust/core/src/serde/physical_plan/mod.rs  |   1 +
 .../rust/core/src/serde/physical_plan/to_proto.rs  |   1 +
 datafusion/src/logical_plan/builder.rs             |  14 ++
 datafusion/src/logical_plan/plan.rs                |   2 +
 datafusion/src/optimizer/projection_push_down.rs   |   2 +
 datafusion/src/optimizer/utils.rs                  |   2 +
 .../physical_optimizer/hash_build_probe_order.rs   |   3 +
 datafusion/src/physical_plan/hash_join.rs          | 212 ++++++++++++++++-----
 datafusion/src/physical_plan/planner.rs            |   3 +
 datafusion/src/sql/planner.rs                      |  26 ++-
 datafusion/tests/sql.rs                            |  65 +++++++
 14 files changed, 282 insertions(+), 54 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 62b3185..1815811 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -435,6 +435,7 @@ message JoinNode {
   JoinConstraint join_constraint = 4;
   repeated Column left_join_column = 5;
   repeated Column right_join_column = 6;
+  bool null_equals_null = 7;
 }
 
 message CrossJoinNode {
@@ -649,6 +650,7 @@ message HashJoinExecNode {
   repeated JoinOn on = 3;
   JoinType join_type = 4;
   PartitionMode partition_mode = 6;
+  bool null_equals_null = 7;
 }
 
 message CrossJoinExecNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 1d1d48e..80af698 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -850,6 +850,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
                 on,
                 join_type,
                 join_constraint,
+                null_equals_null,
                 ..
             } => {
                 let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
@@ -868,6 +869,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
                             join_constraint: join_constraint.into(),
                             left_join_column,
                             right_join_column,
+                            null_equals_null: *null_equals_null,
                         },
                     ))),
                 })
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 99d2de0..3c05957 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -365,6 +365,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     on,
                     &join_type.into(),
                     partition_mode,
+                    &hashjoin.null_equals_null,
                 )?))
             }
             PhysicalPlanType::CrossJoin(crossjoin) => {
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index 4bf013a..aca8f64 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -107,6 +107,7 @@ mod roundtrip_tests {
                     on.clone(),
                     join_type,
                     *partition_mode,
+                    &false,
                 )?))?;
             }
         }
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index afbb02a..41484db 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -162,6 +162,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
                         on,
                         join_type: join_type.into(),
                         partition_mode: partition_mode.into(),
+                        null_equals_null: *exec.null_equals_null(),
                     },
                 ))),
             })
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index a9d814f..ac8e0c3 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -493,6 +493,18 @@ impl LogicalPlanBuilder {
         join_type: JoinType,
         join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
     ) -> Result<Self> {
+        self.join_detailed(right, join_type, join_keys, false)
+    }
+
+    /// Apply a join with on constraint and specified null equality
+    /// If null_equals_null is true then null == null, else null != null
+    pub fn join_detailed(
+        &self,
+        right: &LogicalPlan,
+        join_type: JoinType,
+        join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
+        null_equals_null: bool,
+    ) -> Result<Self> {
         if join_keys.0.len() != join_keys.1.len() {
             return Err(DataFusionError::Plan(
                 "left_keys and right_keys were not the same length".to_string(),
@@ -580,6 +592,7 @@ impl LogicalPlanBuilder {
             join_type,
             join_constraint: JoinConstraint::On,
             schema: DFSchemaRef::new(join_schema),
+            null_equals_null,
         }))
     }
 
@@ -611,6 +624,7 @@ impl LogicalPlanBuilder {
             join_type,
             join_constraint: JoinConstraint::Using,
             schema: DFSchemaRef::new(join_schema),
+            null_equals_null: false,
         }))
     }
 
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 13921d5..f53d8e6 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -135,6 +135,8 @@ pub enum LogicalPlan {
         join_constraint: JoinConstraint,
         /// The output schema, containing fields from the left and right inputs
         schema: DFSchemaRef,
+        /// If null_equals_null is true, null == null else null != null
+        null_equals_null: bool,
     },
     /// Apply Cross Join to two logical plans
     CrossJoin {
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 4fabc4f..78ea278 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -195,6 +195,7 @@ fn optimize_plan(
             on,
             join_type,
             join_constraint,
+            null_equals_null,
             ..
         } => {
             for (l, r) in on {
@@ -231,6 +232,7 @@ fn optimize_plan(
                 join_constraint: *join_constraint,
                 on: on.clone(),
                 schema: DFSchemaRef::new(schema),
+                null_equals_null: *null_equals_null,
             })
         }
         LogicalPlan::Window {
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index f36330e..8ce0d3a 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -198,6 +198,7 @@ pub fn from_plan(
             join_type,
             join_constraint,
             on,
+            null_equals_null,
             ..
         } => {
             let schema =
@@ -209,6 +210,7 @@ pub fn from_plan(
                 join_constraint: *join_constraint,
                 on: on.clone(),
                 schema: DFSchemaRef::new(schema),
+                null_equals_null: *null_equals_null,
             })
         }
         LogicalPlan::CrossJoin { .. } => {
diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
index 0b87ceb..0d1c39f 100644
--- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs
@@ -123,6 +123,7 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
                         .collect(),
                     &swap_join_type(*hash_join.join_type()),
                     *hash_join.partition_mode(),
+                    hash_join.null_equals_null(),
                 )?;
                 let proj = ProjectionExec::try_new(
                     swap_reverting_projection(&*left.schema(), &*right.schema()),
@@ -195,6 +196,7 @@ mod tests {
             )],
             &JoinType::Left,
             PartitionMode::CollectLeft,
+            &false,
         )
         .unwrap();
 
@@ -238,6 +240,7 @@ mod tests {
             )],
             &JoinType::Left,
             PartitionMode::CollectLeft,
+            &false,
         )
         .unwrap();
 
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 2ed0faa..727d1c6 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -120,6 +120,8 @@ pub struct HashJoinExec {
     metrics: ExecutionPlanMetricsSet,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
+    /// If null_equals_null is true, null == null else null != null
+    null_equals_null: bool,
 }
 
 /// Metrics for HashJoinExec
@@ -180,6 +182,7 @@ impl HashJoinExec {
         on: JoinOn,
         join_type: &JoinType,
         partition_mode: PartitionMode,
+        null_equals_null: &bool,
     ) -> Result<Self> {
         let left_schema = left.schema();
         let right_schema = right.schema();
@@ -201,6 +204,7 @@ impl HashJoinExec {
             mode: partition_mode,
             metrics: ExecutionPlanMetricsSet::new(),
             column_indices,
+            null_equals_null: *null_equals_null,
         })
     }
 
@@ -228,6 +232,11 @@ impl HashJoinExec {
     pub fn partition_mode(&self) -> &PartitionMode {
         &self.mode
     }
+
+    /// Get null_equals_null
+    pub fn null_equals_null(&self) -> &bool {
+        &self.null_equals_null
+    }
 }
 
 #[async_trait]
@@ -255,6 +264,7 @@ impl ExecutionPlan for HashJoinExec {
                 self.on.clone(),
                 &self.join_type,
                 self.mode,
+                &self.null_equals_null,
             )?)),
             _ => Err(DataFusionError::Internal(
                 "HashJoinExec wrong number of children".to_string(),
@@ -406,6 +416,7 @@ impl ExecutionPlan for HashJoinExec {
             self.random_state.clone(),
             visited_left_side,
             HashJoinMetrics::new(partition, &self.metrics),
+            self.null_equals_null,
         )))
     }
 
@@ -499,6 +510,8 @@ struct HashJoinStream {
     join_metrics: HashJoinMetrics,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
+    /// If null_equals_null is true, null == null else null != null
+    null_equals_null: bool,
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -514,6 +527,7 @@ impl HashJoinStream {
         random_state: RandomState,
         visited_left_side: Vec<bool>,
         join_metrics: HashJoinMetrics,
+        null_equals_null: bool,
     ) -> Self {
         HashJoinStream {
             schema,
@@ -527,6 +541,7 @@ impl HashJoinStream {
             visited_left_side,
             is_exhausted: false,
             join_metrics,
+            null_equals_null,
         }
     }
 }
@@ -581,10 +596,18 @@ fn build_batch(
     schema: &Schema,
     column_indices: &[ColumnIndex],
     random_state: &RandomState,
+    null_equals_null: &bool,
 ) -> ArrowResult<(RecordBatch, UInt64Array)> {
-    let (left_indices, right_indices) =
-        build_join_indexes(left_data, batch, join_type, on_left, on_right, random_state)
-            .unwrap();
+    let (left_indices, right_indices) = build_join_indexes(
+        left_data,
+        batch,
+        join_type,
+        on_left,
+        on_right,
+        random_state,
+        null_equals_null,
+    )
+    .unwrap();
 
     if matches!(join_type, JoinType::Semi | JoinType::Anti) {
         return Ok((
@@ -637,6 +660,7 @@ fn build_join_indexes(
     left_on: &[Column],
     right_on: &[Column],
     random_state: &RandomState,
+    null_equals_null: &bool,
 ) -> Result<(UInt64Array, UInt32Array)> {
     let keys_values = right_on
         .iter()
@@ -668,7 +692,13 @@ fn build_join_indexes(
                 {
                     for &i in indices {
                         // Check hash collisions
-                        if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
+                        if equal_rows(
+                            i as usize,
+                            row,
+                            &left_join_values,
+                            &keys_values,
+                            *null_equals_null,
+                        )? {
                             left_indices.append(i);
                             right_indices.append(row as u32);
                         }
@@ -702,7 +732,13 @@ fn build_join_indexes(
                 {
                     for &i in indices {
                         // Collision check
-                        if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
+                        if equal_rows(
+                            i as usize,
+                            row,
+                            &left_join_values,
+                            &keys_values,
+                            *null_equals_null,
+                        )? {
                             left_indices.append_value(i)?;
                             right_indices.append_value(row as u32)?;
                         }
@@ -725,6 +761,7 @@ fn build_join_indexes(
                                 row,
                                 &left_join_values,
                                 &keys_values,
+                                *null_equals_null,
                             )? {
                                 left_indices.append_value(i)?;
                                 right_indices.append_value(row as u32)?;
@@ -751,12 +788,13 @@ fn build_join_indexes(
 }
 
 macro_rules! equal_rows_elem {
-    ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident) => {{
+    ($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
         let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
         let right_array = $r.as_any().downcast_ref::<$array_type>().unwrap();
 
         match (left_array.is_null($left), right_array.is_null($right)) {
             (false, false) => left_array.value($left) == right_array.value($right),
+            (true, true) => $null_equals_null,
             _ => false,
         }
     }};
@@ -768,6 +806,7 @@ fn equal_rows(
     right: usize,
     left_arrays: &[ArrayRef],
     right_arrays: &[ArrayRef],
+    null_equals_null: bool,
 ) -> Result<bool> {
     let mut err = None;
     let res = left_arrays
@@ -775,33 +814,87 @@ fn equal_rows(
         .zip(right_arrays)
         .all(|(l, r)| match l.data_type() {
             DataType::Null => true,
-            DataType::Boolean => equal_rows_elem!(BooleanArray, l, r, left, right),
-            DataType::Int8 => equal_rows_elem!(Int8Array, l, r, left, right),
-            DataType::Int16 => equal_rows_elem!(Int16Array, l, r, left, right),
-            DataType::Int32 => equal_rows_elem!(Int32Array, l, r, left, right),
-            DataType::Int64 => equal_rows_elem!(Int64Array, l, r, left, right),
-            DataType::UInt8 => equal_rows_elem!(UInt8Array, l, r, left, right),
-            DataType::UInt16 => equal_rows_elem!(UInt16Array, l, r, left, right),
-            DataType::UInt32 => equal_rows_elem!(UInt32Array, l, r, left, right),
-            DataType::UInt64 => equal_rows_elem!(UInt64Array, l, r, left, right),
-            DataType::Float32 => equal_rows_elem!(Float32Array, l, r, left, right),
-            DataType::Float64 => equal_rows_elem!(Float64Array, l, r, left, right),
+            DataType::Boolean => {
+                equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
+            }
+            DataType::Int8 => {
+                equal_rows_elem!(Int8Array, l, r, left, right, null_equals_null)
+            }
+            DataType::Int16 => {
+                equal_rows_elem!(Int16Array, l, r, left, right, null_equals_null)
+            }
+            DataType::Int32 => {
+                equal_rows_elem!(Int32Array, l, r, left, right, null_equals_null)
+            }
+            DataType::Int64 => {
+                equal_rows_elem!(Int64Array, l, r, left, right, null_equals_null)
+            }
+            DataType::UInt8 => {
+                equal_rows_elem!(UInt8Array, l, r, left, right, null_equals_null)
+            }
+            DataType::UInt16 => {
+                equal_rows_elem!(UInt16Array, l, r, left, right, null_equals_null)
+            }
+            DataType::UInt32 => {
+                equal_rows_elem!(UInt32Array, l, r, left, right, null_equals_null)
+            }
+            DataType::UInt64 => {
+                equal_rows_elem!(UInt64Array, l, r, left, right, null_equals_null)
+            }
+            DataType::Float32 => {
+                equal_rows_elem!(Float32Array, l, r, left, right, null_equals_null)
+            }
+            DataType::Float64 => {
+                equal_rows_elem!(Float64Array, l, r, left, right, null_equals_null)
+            }
             DataType::Timestamp(time_unit, None) => match time_unit {
                 TimeUnit::Second => {
-                    equal_rows_elem!(TimestampSecondArray, l, r, left, right)
+                    equal_rows_elem!(
+                        TimestampSecondArray,
+                        l,
+                        r,
+                        left,
+                        right,
+                        null_equals_null
+                    )
                 }
                 TimeUnit::Millisecond => {
-                    equal_rows_elem!(TimestampMillisecondArray, l, r, left, right)
+                    equal_rows_elem!(
+                        TimestampMillisecondArray,
+                        l,
+                        r,
+                        left,
+                        right,
+                        null_equals_null
+                    )
                 }
                 TimeUnit::Microsecond => {
-                    equal_rows_elem!(TimestampMicrosecondArray, l, r, left, right)
+                    equal_rows_elem!(
+                        TimestampMicrosecondArray,
+                        l,
+                        r,
+                        left,
+                        right,
+                        null_equals_null
+                    )
                 }
                 TimeUnit::Nanosecond => {
-                    equal_rows_elem!(TimestampNanosecondArray, l, r, left, right)
+                    equal_rows_elem!(
+                        TimestampNanosecondArray,
+                        l,
+                        r,
+                        left,
+                        right,
+                        null_equals_null
+                    )
                 }
             },
-            DataType::Utf8 => equal_rows_elem!(StringArray, l, r, left, right),
-            DataType::LargeUtf8 => equal_rows_elem!(LargeStringArray, l, r, left, right),
+            DataType::Utf8 => {
+                equal_rows_elem!(StringArray, l, r, left, right, null_equals_null)
+            }
+            DataType::LargeUtf8 => {
+                equal_rows_elem!(LargeStringArray, l, r, left, right, null_equals_null)
+            }
             _ => {
                 // This is internal because we should have caught this before.
                 err = Some(Err(DataFusionError::Internal(
@@ -883,6 +976,7 @@ impl Stream for HashJoinStream {
                         &self.schema,
                         &self.column_indices,
                         &self.random_state,
+                        &self.null_equals_null,
                     );
                     self.join_metrics.input_batches.add(1);
                     self.join_metrics.input_rows.add(batch.num_rows());
@@ -976,8 +1070,16 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
+        null_equals_null: bool,
     ) -> Result<HashJoinExec> {
-        HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft)
+        HashJoinExec::try_new(
+            left,
+            right,
+            on,
+            join_type,
+            PartitionMode::CollectLeft,
+            &null_equals_null,
+        )
     }
 
     async fn join_collect(
@@ -985,8 +1087,9 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
+        null_equals_null: bool,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
-        let join = join(left, right, on, join_type)?;
+        let join = join(left, right, on, join_type, null_equals_null)?;
         let columns = columns(&join.schema());
 
         let stream = join.execute(0).await?;
@@ -1000,6 +1103,7 @@ mod tests {
         right: Arc<dyn ExecutionPlan>,
         on: JoinOn,
         join_type: &JoinType,
+        null_equals_null: bool,
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let partition_count = 4;
 
@@ -1025,6 +1129,7 @@ mod tests {
             on,
             join_type,
             PartitionMode::Partitioned,
+            &null_equals_null,
         )?;
 
         let columns = columns(&join.schema());
@@ -1062,9 +1167,14 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let (columns, batches) =
-            join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner)
-                .await?;
+        let (columns, batches) = join_collect(
+            left.clone(),
+            right.clone(),
+            on.clone(),
+            &JoinType::Inner,
+            false,
+        )
+        .await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
@@ -1104,6 +1214,7 @@ mod tests {
             right.clone(),
             on.clone(),
             &JoinType::Inner,
+            false,
         )
         .await?;
 
@@ -1140,7 +1251,8 @@ mod tests {
             Column::new_with_schema("b2", &right.schema())?,
         )];
 
-        let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+        let (columns, batches) =
+            join_collect(left, right, on, &JoinType::Inner, false).await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
 
@@ -1182,7 +1294,8 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+        let (columns, batches) =
+            join_collect(left, right, on, &JoinType::Inner, false).await?;
 
         assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
 
@@ -1234,7 +1347,8 @@ mod tests {
             ),
         ];
 
-        let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
+        let (columns, batches) =
+            join_collect(left, right, on, &JoinType::Inner, false).await?;
 
         assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
 
@@ -1281,7 +1395,7 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let join = join(left, right, on, &JoinType::Inner)?;
+        let join = join(left, right, on, &JoinType::Inner, false)?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1347,7 +1461,7 @@ mod tests {
             Column::new_with_schema("b1", &right.schema()).unwrap(),
         )];
 
-        let join = join(left, right, on, &JoinType::Left).unwrap();
+        let join = join(left, right, on, &JoinType::Left, false).unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1388,7 +1502,7 @@ mod tests {
             Column::new_with_schema("b2", &right.schema()).unwrap(),
         )];
 
-        let join = join(left, right, on, &JoinType::Full).unwrap();
+        let join = join(left, right, on, &JoinType::Full, false).unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1427,7 +1541,7 @@ mod tests {
         )];
         let schema = right.schema();
         let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
-        let join = join(left, right, on, &JoinType::Left).unwrap();
+        let join = join(left, right, on, &JoinType::Left, false).unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1462,7 +1576,7 @@ mod tests {
         )];
         let schema = right.schema();
         let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
-        let join = join(left, right, on, &JoinType::Full).unwrap();
+        let join = join(left, right, on, &JoinType::Full, false).unwrap();
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1500,9 +1614,14 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let (columns, batches) =
-            join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left)
-                .await?;
+        let (columns, batches) = join_collect(
+            left.clone(),
+            right.clone(),
+            on.clone(),
+            &JoinType::Left,
+            false,
+        )
+        .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
         let expected = vec![
@@ -1541,6 +1660,7 @@ mod tests {
             right.clone(),
             on.clone(),
             &JoinType::Left,
+            false,
         )
         .await?;
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
@@ -1576,7 +1696,7 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let join = join(left, right, on, &JoinType::Semi)?;
+        let join = join(left, right, on, &JoinType::Semi, false)?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -1615,7 +1735,7 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let join = join(left, right, on, &JoinType::Anti)?;
+        let join = join(left, right, on, &JoinType::Anti, false)?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1"]);
@@ -1652,7 +1772,8 @@ mod tests {
             Column::new_with_schema("b1", &right.schema())?,
         )];
 
-        let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?;
+        let (columns, batches) =
+            join_collect(left, right, on, &JoinType::Right, false).await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
@@ -1689,7 +1810,7 @@ mod tests {
         )];
 
         let (columns, batches) =
-            partitioned_join_collect(left, right, on, &JoinType::Right).await?;
+            partitioned_join_collect(left, right, on, &JoinType::Right, false).await?;
 
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
 
@@ -1725,7 +1846,7 @@ mod tests {
             Column::new_with_schema("b2", &right.schema()).unwrap(),
         )];
 
-        let join = join(left, right, on, &JoinType::Full)?;
+        let join = join(left, right, on, &JoinType::Full, false)?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
@@ -1780,6 +1901,7 @@ mod tests {
             &[Column::new("a", 0)],
             &[Column::new("a", 0)],
             &random_state,
+            &false,
         )?;
 
         let mut left_ids = UInt64Builder::new(0);
@@ -1815,7 +1937,7 @@ mod tests {
             Column::new_with_schema("b", &right.schema()).unwrap(),
         )];
 
-        let join = join(left, right, on, &JoinType::Inner)?;
+        let join = join(left, right, on, &JoinType::Inner, false)?;
 
         let columns = columns(&join.schema());
         assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]);
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 87fc166..12e7401 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -713,6 +713,7 @@ impl DefaultPhysicalPlanner {
                     right,
                     on: keys,
                     join_type,
+                    null_equals_null,
                     ..
                 } => {
                     let left_df_schema = left.schema();
@@ -761,6 +762,7 @@ impl DefaultPhysicalPlanner {
                             join_on,
                             join_type,
                             PartitionMode::Partitioned,
+                            null_equals_null,
                         )?))
                     } else {
                         Ok(Arc::new(HashJoinExec::try_new(
@@ -769,6 +771,7 @@ impl DefaultPhysicalPlanner {
                             join_on,
                             join_type,
                             PartitionMode::CollectLeft,
+                            null_equals_null,
                         )?))
                     }
                 }
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 1653cb5..de09632 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -192,23 +192,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 left,
                 right,
                 all,
-            } => match (op, all) {
+            } => {
+                let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
+                let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
+                match (op, all) {
                 (SetOperator::Union, true) => {
-                    let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
-                    let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
                     union_with_alias(left_plan, right_plan, alias)
                 }
                 (SetOperator::Union, false) => {
-                    let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
-                    let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
                     let union_plan = union_with_alias(left_plan, right_plan, alias)?;
                     LogicalPlanBuilder::from(union_plan).distinct()?.build()
                 }
+                (SetOperator::Intersect, true) => {
+                    let join_keys = left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field, right_field)| ((Column::from_name(left_field.name())), (Column::from_name(right_field.name())))).unzip();
+                    LogicalPlanBuilder::from(left_plan).join_detailed(&right_plan, JoinType::Semi, join_keys, true)?.build()
+                }
+                (SetOperator::Intersect, false) => {
+                    let join_keys = left_plan.schema().fields().iter().zip(right_plan.schema().fields().iter()).map(|(left_field, right_field)| ((Column::from_name(left_field.name())), (Column::from_name(right_field.name())))).unzip();
+                    LogicalPlanBuilder::from(left_plan).distinct()?.join_detailed(&right_plan, JoinType::Semi, join_keys, true)?.build()
+                }
                 _ => Err(DataFusionError::NotImplemented(format!(
-                    "Only UNION ALL and UNION [DISTINCT] are supported, found {}",
+                    "Only UNION ALL and UNION [DISTINCT] and INTERSECT and INTERSECT [DISTINCT] are supported, found {}",
                     op
                 ))),
-            },
+                }
+            }
             _ => Err(DataFusionError::NotImplemented(format!(
                 "Query {} not implemented yet",
                 set_expr
@@ -3543,11 +3551,11 @@ mod tests {
     }
 
     #[test]
-    fn only_union_all_supported() {
+    fn except_not_supported() { // EXCEPT is still unsupported, so planning it must fail
         let sql = "SELECT order_id from orders EXCEPT SELECT order_id FROM orders";
         let err = logical_plan(sql).expect_err("query should have failed");
         assert_eq!(
-            "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] are supported, found EXCEPT\")",
+            "NotImplemented(\"Only UNION ALL and UNION [DISTINCT] and INTERSECT and INTERSECT [DISTINCT] are supported, found EXCEPT\")",
             format!("{:?}", err)
         );
     }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 3af5501..996908a 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -5580,3 +5580,68 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
     assert_eq!(expected, actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn intersect_with_null_not_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2";
+    // Both rows have NULL id1 but differ on id2, so INTERSECT must return no rows.
+    let expected: &[&[&str]] = &[];
+    // The SQL reads only the inline subqueries t1 and t2, not any registered table.
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn intersect_with_null_equal() {
+    let sql = "SELECT * FROM (SELECT null AS id1, 1 AS id2) t1
+            INTERSECT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2";
+    // INTERSECT compares NULLs as equal, so the identical (NULL, 1) rows match and one is kept.
+    let expected: Vec<Vec<String>> = vec![vec!["NULL".to_string(), "1".to_string()]];
+    // The SQL reads only the inline subqueries t1 and t2, not any registered table.
+    let mut ctx = create_join_context_qualified().unwrap();
+    let actual = execute(&mut ctx, sql).await;
+
+    assert_eq!(expected, actual);
+}
+
+#[tokio::test]
+async fn test_intersect_all() -> Result<()> { // INTERSECT ALL must keep duplicate rows
+    let mut ctx = ExecutionContext::new();
+    register_alltypes_parquet(&mut ctx).await; // presumably registers `alltypes_plain`
+    // NOTE(review): unfiltered right side contains every left row; right dup counts untested
+    let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![ // repeated rows are kept, not collapsed as with plain INTERSECT
+        "+---------+------------+",
+        "| int_col | double_col |",
+        "+---------+------------+",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "| 1       | 10.1       |",
+        "+---------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_intersect_distinct() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_alltypes_parquet(&mut ctx).await; // presumably registers `alltypes_plain`
+    // Plain INTERSECT (no ALL) deduplicates, so each matching row appears exactly once.
+    let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+---------+------------+",
+        "| int_col | double_col |",
+        "+---------+------------+",
+        "| 1       | 10.1       |",
+        "+---------+------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}