You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2022/10/30 15:39:09 UTC

[arrow-datafusion] branch master updated: Add right anti join support and support it in HashBuildProbeOrder (#4011)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new d6f0b1261 Add right anti join support and support it in HashBuildProbeOrder (#4011)
d6f0b1261 is described below

commit d6f0b1261159f6f97cb14357489d9081e8bbd7ab
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Sun Oct 30 16:39:03 2022 +0100

    Add right anti join support and support it in HashBuildProbeOrder (#4011)
    
    * Add right anti join support
    
    * Fix
---
 .../physical_optimizer/hash_build_probe_order.rs   |  82 +++++++++-------
 .../core/src/physical_plan/joins/hash_join.rs      | 109 +++++++++++++++++++--
 .../src/physical_plan/joins/sort_merge_join.rs     |   6 +-
 datafusion/core/src/physical_plan/joins/utils.rs   |  10 +-
 datafusion/expr/src/logical_plan/builder.rs        |   5 +-
 datafusion/expr/src/logical_plan/plan.rs           |   3 +
 datafusion/optimizer/src/filter_push_down.rs       |   7 +-
 datafusion/proto/proto/datafusion.proto            |   1 +
 datafusion/proto/src/generated/pbjson.rs           |   3 +
 datafusion/proto/src/generated/prost.rs            |   2 +
 datafusion/proto/src/logical_plan.rs               |   2 +
 11 files changed, 178 insertions(+), 52 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
index 8b9279e2c..3a7e0dad2 100644
--- a/datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
+++ b/datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
@@ -75,8 +75,9 @@ fn supports_swap(join_type: JoinType) -> bool {
         | JoinType::Right
         | JoinType::Full
         | JoinType::LeftSemi
-        | JoinType::RightSemi => true,
-        JoinType::LeftAnti => false,
+        | JoinType::RightSemi
+        | JoinType::LeftAnti
+        | JoinType::RightAnti => true,
     }
 }
 
@@ -88,7 +89,8 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
         JoinType::Right => JoinType::Left,
         JoinType::LeftSemi => JoinType::RightSemi,
         JoinType::RightSemi => JoinType::LeftSemi,
-        _ => unreachable!(),
+        JoinType::LeftAnti => JoinType::RightAnti,
+        JoinType::RightAnti => JoinType::LeftAnti,
     }
 }
 
@@ -176,7 +178,10 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
                 )?;
                 if matches!(
                     hash_join.join_type(),
-                    JoinType::LeftSemi | JoinType::RightSemi
+                    JoinType::LeftSemi
+                        | JoinType::RightSemi
+                        | JoinType::LeftAnti
+                        | JoinType::RightAnti
                 ) {
                     return Ok(Arc::new(new_join));
                 }
@@ -362,45 +367,48 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_join_with_swap_left_semi() {
-        let (big, small) = create_big_and_small();
-
-        let join = HashJoinExec::try_new(
-            Arc::clone(&big),
-            Arc::clone(&small),
-            vec![(
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
-            )],
-            None,
-            &JoinType::LeftSemi,
-            PartitionMode::CollectLeft,
-            &false,
-        )
-        .unwrap();
+    async fn test_join_with_swap_semi() {
+        let join_types = [JoinType::LeftSemi, JoinType::LeftAnti];
+        for join_type in join_types {
+            let (big, small) = create_big_and_small();
+
+            let join = HashJoinExec::try_new(
+                Arc::clone(&big),
+                Arc::clone(&small),
+                vec![(
+                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                    Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                )],
+                None,
+                &join_type,
+                PartitionMode::CollectLeft,
+                &false,
+            )
+            .unwrap();
 
-        let original_schema = join.schema();
+            let original_schema = join.schema();
 
-        let optimized_join = HashBuildProbeOrder::new()
-            .optimize(Arc::new(join), &SessionConfig::new())
-            .unwrap();
+            let optimized_join = HashBuildProbeOrder::new()
+                .optimize(Arc::new(join), &SessionConfig::new())
+                .unwrap();
 
-        let swapped_join = optimized_join
-            .as_any()
-            .downcast_ref::<HashJoinExec>()
-            .expect(
-                "A proj is not required to swap columns back to their original order",
-            );
+            let swapped_join = optimized_join
+                .as_any()
+                .downcast_ref::<HashJoinExec>()
+                .expect(
+                    "A proj is not required to swap columns back to their original order",
+                );
 
-        assert_eq!(swapped_join.schema().fields().len(), 1);
+            assert_eq!(swapped_join.schema().fields().len(), 1);
 
-        assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
-        assert_eq!(
-            swapped_join.right().statistics().total_byte_size,
-            Some(100000)
-        );
+            assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
+            assert_eq!(
+                swapped_join.right().statistics().total_byte_size,
+                Some(100000)
+            );
 
-        assert_eq!(original_schema, swapped_join.schema());
+            assert_eq!(original_schema, swapped_join.schema());
+        }
     }
 
     /// Compare the input plan with the plan after running the probe order optimizer.
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index b41c0df1e..99628d7da 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -789,7 +789,6 @@ fn build_join_indexes(
                             &keys_values,
                             *null_equals_null,
                         )? {
-                            left_indices.append(i);
                             right_indices.append(row as u32);
                             break;
                         }
@@ -813,6 +812,59 @@ fn build_join_indexes(
                 PrimitiveArray::<UInt32Type>::from(right),
             ))
         }
+        JoinType::RightAnti => {
+            let mut left_indices = UInt64BufferBuilder::new(0);
+            let mut right_indices = UInt32BufferBuilder::new(0);
+
+            // Visit all of the right rows
+            for (row, hash_value) in hash_values.iter().enumerate() {
+                // Get the hash and find it in the build index
+
+                // For every item on the left and right we check if it doesn't match
+                // This possibly contains rows with hash collisions,
+                // So we have to check here whether rows are equal or not
+                // We only produce one row if there is no match
+                let matches = left.0.get(*hash_value, |(hash, _)| *hash_value == *hash);
+                let mut no_match = true;
+                match matches {
+                    Some((_, indices)) => {
+                        for &i in indices {
+                            // Check hash collisions
+                            if equal_rows(
+                                i as usize,
+                                row,
+                                &left_join_values,
+                                &keys_values,
+                                *null_equals_null,
+                            )? {
+                                no_match = false;
+                                break;
+                            }
+                        }
+                    }
+                    None => no_match = true,
+                };
+                if no_match {
+                    right_indices.append(row as u32);
+                }
+            }
+
+            let left = ArrayData::builder(DataType::UInt64)
+                .len(left_indices.len())
+                .add_buffer(left_indices.finish())
+                .build()
+                .unwrap();
+            let right = ArrayData::builder(DataType::UInt32)
+                .len(right_indices.len())
+                .add_buffer(right_indices.finish())
+                .build()
+                .unwrap();
+
+            Ok((
+                PrimitiveArray::<UInt64Type>::from(left),
+                PrimitiveArray::<UInt32Type>::from(right),
+            ))
+        }
         JoinType::Left => {
             let mut left_indices = UInt64Builder::with_capacity(0);
             let mut right_indices = UInt32Builder::with_capacity(0);
@@ -887,7 +939,7 @@ fn apply_join_filter(
     right_indices: UInt32Array,
     filter: &JoinFilter,
 ) -> Result<(UInt64Array, UInt32Array)> {
-    if left_indices.is_empty() {
+    if left_indices.is_empty() && right_indices.is_empty() {
         return Ok((left_indices, right_indices));
     };
 
@@ -904,6 +956,7 @@ fn apply_join_filter(
         JoinType::Inner
         | JoinType::Left
         | JoinType::LeftAnti
+        | JoinType::RightAnti
         | JoinType::LeftSemi
         | JoinType::RightSemi => {
             // For both INNER and LEFT joins, input arrays contains only indices for matched data.
@@ -1342,9 +1395,10 @@ impl HashJoinStream {
 
                     buffer
                 }
-                JoinType::Inner | JoinType::Right | JoinType::RightSemi => {
-                    BooleanBufferBuilder::new(0)
-                }
+                JoinType::Inner
+                | JoinType::Right
+                | JoinType::RightSemi
+                | JoinType::RightAnti => BooleanBufferBuilder::new(0),
             }
         });
 
@@ -1381,7 +1435,10 @@ impl HashJoinStream {
                                     visited_left_side.set_bit(x as usize, true);
                                 });
                             }
-                            JoinType::Inner | JoinType::Right | JoinType::RightSemi => {}
+                            JoinType::Inner
+                            | JoinType::Right
+                            | JoinType::RightSemi
+                            | JoinType::RightAnti => {}
                         }
                     }
                     Some(result.map(|x| x.0))
@@ -1420,6 +1477,7 @@ impl HashJoinStream {
                         | JoinType::LeftSemi
                         | JoinType::RightSemi
                         | JoinType::LeftAnti
+                        | JoinType::RightAnti
                         | JoinType::Inner
                         | JoinType::Right => {}
                     }
@@ -2255,6 +2313,45 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn join_right_anti() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let right = build_table(
+            ("a1", &vec![1, 2, 2, 3, 5]),
+            ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the left
+            ("c1", &vec![7, 8, 8, 9, 11]),
+        );
+        let left = build_table(
+            ("a2", &vec![10, 20, 30, 40]),
+            ("b2", &vec![4, 5, 6, 5]), // 5 is double on the left
+            ("c2", &vec![70, 80, 90, 100]),
+        );
+        let on = vec![(
+            Column::new_with_schema("b2", &left.schema())?,
+            Column::new_with_schema("b1", &right.schema())?,
+        )];
+
+        let join = join(left, right, on, &JoinType::RightAnti, false)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1"]);
+
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 3  | 7  | 9  |",
+            "| 5  | 7  | 11 |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+        Ok(())
+    }
+
     #[tokio::test]
     async fn join_anti_with_filter() -> Result<()> {
         let session_ctx = SessionContext::new();
diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
index dfcab88c2..92392455e 100644
--- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs
@@ -139,7 +139,9 @@ impl ExecutionPlan for SortMergeJoinExec {
             | JoinType::Left
             | JoinType::LeftSemi
             | JoinType::LeftAnti => self.left.output_ordering(),
-            JoinType::Right | JoinType::RightSemi => self.right.output_ordering(),
+            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
+                self.right.output_ordering()
+            }
             JoinType::Full => None,
         }
     }
@@ -187,7 +189,7 @@ impl ExecutionPlan for SortMergeJoinExec {
                 self.on.iter().map(|on| on.0.clone()).collect(),
                 self.on.iter().map(|on| on.1.clone()).collect(),
             ),
-            JoinType::Right | JoinType::RightSemi => (
+            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (
                 self.right.clone(),
                 self.left.clone(),
                 self.on.iter().map(|on| on.1.clone()).collect(),
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 1d1560478..d041e7dfb 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -172,6 +172,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) ->
         JoinType::LeftSemi => false, // doesn't introduce nulls
         JoinType::RightSemi => false, // doesn't introduce nulls
         JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??)
+        JoinType::RightAnti => false, // doesn't introduce nulls (or can it??)
     };
 
     if force_nullable {
@@ -237,7 +238,7 @@ pub fn build_join_schema(
                 )
             })
             .unzip(),
-        JoinType::RightSemi => right
+        JoinType::RightSemi | JoinType::RightAnti => right
             .fields()
             .iter()
             .cloned()
@@ -410,9 +411,10 @@ fn estimate_join_cardinality(
             })
         }
 
-        JoinType::LeftSemi => None,
-        JoinType::LeftAnti => None,
-        JoinType::RightSemi => None,
+        JoinType::LeftSemi
+        | JoinType::RightSemi
+        | JoinType::LeftAnti
+        | JoinType::RightAnti => None,
     }
 }
 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 829aa6682..30b13c2f0 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -827,7 +827,10 @@ pub fn build_join_schema(
             // Only use the left side for the schema
             left.fields().clone()
         }
-        JoinType::RightSemi => right.fields().clone(),
+        JoinType::RightSemi | JoinType::RightAnti => {
+            // Only use the right side for the schema
+            right.fields().clone()
+        }
     };
 
     let mut metadata = left.metadata().clone();
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 38b3a789a..d65ed5228 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -982,6 +982,8 @@ pub enum JoinType {
     RightSemi,
     /// Left Anti Join
     LeftAnti,
+    /// Right Anti Join
+    RightAnti,
 }
 
 impl Display for JoinType {
@@ -994,6 +996,7 @@ impl Display for JoinType {
             JoinType::LeftSemi => "LeftSemi",
             JoinType::RightSemi => "RightSemi",
             JoinType::LeftAnti => "LeftAnti",
+            JoinType::RightAnti => "RightAnti",
         };
         write!(f, "{}", join_type)
     }
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs
index 255732b70..0539a8962 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -179,7 +179,7 @@ fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
             JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
             // No columns from the left side of the join can be referenced in output
             // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
-            JoinType::RightSemi => Ok((false, true)),
+            JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
         },
         LogicalPlan::CrossJoin(_) => Ok((true, true)),
         _ => Err(DataFusionError::Internal(
@@ -198,7 +198,10 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
             JoinType::Left => Ok((false, true)),
             JoinType::Right => Ok((true, false)),
             JoinType::Full => Ok((false, false)),
-            JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi => {
+            JoinType::LeftSemi
+            | JoinType::LeftAnti
+            | JoinType::RightSemi
+            | JoinType::RightAnti => {
                 // filter_push_down does not yet support SEMI/ANTI joins with join conditions
                 Ok((false, false))
             }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 78b5669ed..e0f8f5160 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -233,6 +233,7 @@ enum JoinType {
   LEFTSEMI = 4;
   LEFTANTI = 5;
   RIGHTSEMI = 6;
+  RIGHTANTI = 7;
 }
 
 enum JoinConstraint {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index e8110b431..71502f812 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -6405,6 +6405,7 @@ impl serde::Serialize for JoinType {
             Self::Leftsemi => "LEFTSEMI",
             Self::Leftanti => "LEFTANTI",
             Self::Rightsemi => "RIGHTSEMI",
+            Self::Rightanti => "RIGHTANTI",
         };
         serializer.serialize_str(variant)
     }
@@ -6423,6 +6424,7 @@ impl<'de> serde::Deserialize<'de> for JoinType {
             "LEFTSEMI",
             "LEFTANTI",
             "RIGHTSEMI",
+            "RIGHTANTI",
         ];
 
         struct GeneratedVisitor;
@@ -6472,6 +6474,7 @@ impl<'de> serde::Deserialize<'de> for JoinType {
                     "LEFTSEMI" => Ok(JoinType::Leftsemi),
                     "LEFTANTI" => Ok(JoinType::Leftanti),
                     "RIGHTSEMI" => Ok(JoinType::Rightsemi),
+                    "RIGHTANTI" => Ok(JoinType::Rightanti),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index cf6978faf..8287d1289 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1113,6 +1113,7 @@ pub enum JoinType {
     Leftsemi = 4,
     Leftanti = 5,
     Rightsemi = 6,
+    Rightanti = 7,
 }
 impl JoinType {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -1128,6 +1129,7 @@ impl JoinType {
             JoinType::Leftsemi => "LEFTSEMI",
             JoinType::Leftanti => "LEFTANTI",
             JoinType::Rightsemi => "RIGHTSEMI",
+            JoinType::Rightanti => "RIGHTANTI",
         }
     }
 }
diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan.rs
index 1f0e40895..13aaceb79 100644
--- a/datafusion/proto/src/logical_plan.rs
+++ b/datafusion/proto/src/logical_plan.rs
@@ -242,6 +242,7 @@ impl From<protobuf::JoinType> for JoinType {
             protobuf::JoinType::Leftsemi => JoinType::LeftSemi,
             protobuf::JoinType::Rightsemi => JoinType::RightSemi,
             protobuf::JoinType::Leftanti => JoinType::LeftAnti,
+            protobuf::JoinType::Rightanti => JoinType::RightAnti,
         }
     }
 }
@@ -256,6 +257,7 @@ impl From<JoinType> for protobuf::JoinType {
             JoinType::LeftSemi => protobuf::JoinType::Leftsemi,
             JoinType::RightSemi => protobuf::JoinType::Rightsemi,
             JoinType::LeftAnti => protobuf::JoinType::Leftanti,
+            JoinType::RightAnti => protobuf::JoinType::Rightanti,
         }
     }
 }