You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/02 22:14:52 UTC

[arrow-datafusion] branch master updated: Support semi join (#470)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 33ff660  Support semi join (#470)
33ff660 is described below

commit 33ff660318bf60d0c9aa45ceba2c4c943bfe9438
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Thu Jun 3 00:14:41 2021 +0200

    Support semi join (#470)
    
    * Support semi join
    
    * Fmt
    
    * Match on Semi
    
    * Simplify
    
    * Fmt
    
    * Undo match
    
    * Update datafusion/src/physical_plan/hash_join.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    * Add item on the left for semi join
    
    * Simplify pattern match
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 ballista/rust/core/proto/ballista.proto            |  1 +
 .../rust/core/src/serde/logical_plan/from_proto.rs |  1 +
 .../rust/core/src/serde/logical_plan/to_proto.rs   |  1 +
 .../core/src/serde/physical_plan/from_proto.rs     |  1 +
 .../rust/core/src/serde/physical_plan/to_proto.rs  |  1 +
 datafusion/src/logical_plan/builder.rs             |  4 +
 datafusion/src/logical_plan/plan.rs                |  4 +-
 datafusion/src/optimizer/hash_build_probe_order.rs | 10 ++-
 datafusion/src/physical_plan/hash_join.rs          | 89 ++++++++++++++++++----
 datafusion/src/physical_plan/hash_utils.rs         |  3 +
 datafusion/src/physical_plan/planner.rs            |  1 +
 11 files changed, 99 insertions(+), 17 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index da0c615..0387214 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -363,6 +363,7 @@ enum JoinType {
   LEFT = 1;
   RIGHT = 2;
   FULL = 3;
+  SEMI = 4;
 }
 
 message JoinNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 10c4670..4847126 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -265,6 +265,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     protobuf::JoinType::Left => JoinType::Left,
                     protobuf::JoinType::Right => JoinType::Right,
                     protobuf::JoinType::Full => JoinType::Full,
+                    protobuf::JoinType::Semi => JoinType::Semi,
                 };
                 LogicalPlanBuilder::from(&convert_box_required!(join.left)?)
                     .join(
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index b630dfc..e1c0c5e 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -834,6 +834,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
                     JoinType::Left => protobuf::JoinType::Left,
                     JoinType::Right => protobuf::JoinType::Right,
                     JoinType::Full => protobuf::JoinType::Full,
+                    JoinType::Semi => protobuf::JoinType::Semi,
                 };
                 let left_join_column = on.iter().map(|on| on.0.to_owned()).collect();
                 let right_join_column = on.iter().map(|on| on.1.to_owned()).collect();
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 2039def..7f98a83 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -379,6 +379,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     protobuf::JoinType::Left => JoinType::Left,
                     protobuf::JoinType::Right => JoinType::Right,
                     protobuf::JoinType::Full => JoinType::Full,
+                    protobuf::JoinType::Semi => JoinType::Semi,
                 };
                 Ok(Arc::new(HashJoinExec::try_new(
                     left,
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 9571f3d..c409f94 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -133,6 +133,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
                 JoinType::Left => protobuf::JoinType::Left,
                 JoinType::Right => protobuf::JoinType::Right,
                 JoinType::Full => protobuf::JoinType::Full,
+                JoinType::Semi => protobuf::JoinType::Semi,
             };
             Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new(
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index c02555d..71de48c 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -410,6 +410,10 @@ fn build_join_schema(
             // left then right
             left_fields.chain(right_fields).cloned().collect()
         }
+        JoinType::Semi => {
+            // Only use the left side for the schema
+            left.fields().clone()
+        }
         JoinType::Right => {
             // remove left-side join keys if they have the same names as the right-side
             let duplicate_keys = &on
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 2d85abb..5cb94be 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -34,7 +34,7 @@ use std::{
 };
 
 /// Join type
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub enum JoinType {
     /// Inner Join
     Inner,
@@ -44,6 +44,8 @@ pub enum JoinType {
     Right,
     /// Full Join
     Full,
+    /// Semi Join
+    Semi,
 }
 
 /// A LogicalPlan represents the different types of relational
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 100ae4f..86d38ef 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -106,6 +106,13 @@ fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool {
     }
 }
 
+fn supports_swap(join_type: JoinType) -> bool {
+    match join_type {
+        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true,
+        JoinType::Semi => false,
+    }
+}
+
 impl OptimizerRule for HashBuildProbeOrder {
     fn name(&self) -> &str {
         "hash_build_probe_order"
@@ -128,7 +135,7 @@ impl OptimizerRule for HashBuildProbeOrder {
             } => {
                 let left = self.optimize(left, execution_props)?;
                 let right = self.optimize(right, execution_props)?;
-                if should_swap_join_order(&left, &right) {
+                if should_swap_join_order(&left, &right) && supports_swap(*join_type) {
                     // Swap left and right, change join type and (equi-)join key order
                     Ok(LogicalPlan::Join {
                         left: Arc::new(right),
@@ -216,6 +223,7 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
         JoinType::Full => JoinType::Full,
         JoinType::Left => JoinType::Right,
         JoinType::Right => JoinType::Left,
+        _ => unreachable!(),
     }
 }
 
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 01551cd..6653b9a 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -184,7 +184,7 @@ impl HashJoinExec {
     /// Calculates column indices and left/right placement on input / output schemas and jointype
     fn column_indices_from_schema(&self) -> ArrowResult<Vec<ColumnIndex>> {
         let (primary_is_left, primary_schema, secondary_schema) = match self.join_type {
-            JoinType::Inner | JoinType::Left | JoinType::Full => {
+            JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Semi => {
                 (true, self.left.schema(), self.right.schema())
             }
             JoinType::Right => (false, self.right.schema(), self.left.schema()),
@@ -376,7 +376,7 @@ impl ExecutionPlan for HashJoinExec {
         let column_indices = self.column_indices_from_schema()?;
         let num_rows = left_data.1.num_rows();
         let visited_left_side = match self.join_type {
-            JoinType::Left | JoinType::Full => vec![false; num_rows],
+            JoinType::Left | JoinType::Full | JoinType::Semi => vec![false; num_rows],
             JoinType::Inner | JoinType::Right => vec![],
         };
         Ok(Box::pin(HashJoinStream {
@@ -544,6 +544,13 @@ fn build_batch(
     )
     .unwrap();
 
+    if join_type == JoinType::Semi {
+        return Ok((
+            RecordBatch::new_empty(Arc::new(schema.clone())),
+            left_indices,
+        ));
+    }
+
     build_batch_from_indices(
         schema,
         &left_data.1,
@@ -606,7 +613,7 @@ fn build_join_indexes(
     let left = &left_data.0;
 
     match join_type {
-        JoinType::Inner => {
+        JoinType::Inner | JoinType::Semi => {
             // Using a buffer builder to avoid slower normal builder
             let mut left_indices = UInt64BufferBuilder::new(0);
             let mut right_indices = UInt32BufferBuilder::new(0);
@@ -1108,23 +1115,35 @@ pub fn create_hashes<'a>(
     Ok(hashes_buffer)
 }
 
-// Produces a batch for left-side rows that are not marked as being visited during the whole join
-fn produce_unmatched(
+// Produces a batch for left-side rows that have/have not been matched during the whole join
+fn produce_from_matched(
     visited_left_side: &[bool],
     schema: &SchemaRef,
     column_indices: &[ColumnIndex],
     left_data: &JoinLeftData,
+    unmatched: bool,
 ) -> ArrowResult<RecordBatch> {
     // Find indices which didn't match any right row (are false)
-    let unmatched_indices: Vec<u64> = visited_left_side
-        .iter()
-        .enumerate()
-        .filter(|&(_, &value)| !value)
-        .map(|(index, _)| index as u64)
-        .collect();
+    let indices = if unmatched {
+        UInt64Array::from_iter_values(
+            visited_left_side
+                .iter()
+                .enumerate()
+                .filter(|&(_, &value)| !value)
+                .map(|(index, _)| index as u64),
+        )
+    } else {
+        // produce those that did match
+        UInt64Array::from_iter_values(
+            visited_left_side
+                .iter()
+                .enumerate()
+                .filter(|&(_, &value)| value)
+                .map(|(index, _)| index as u64),
+        )
+    };
 
     // generate batches by taking values from the left side and generating columns filled with null on the right side
-    let indices = UInt64Array::from_iter_values(unmatched_indices);
     let num_rows = indices.len();
     let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
     for (idx, column_index) in column_indices.iter().enumerate() {
@@ -1171,7 +1190,7 @@ impl Stream for HashJoinStream {
                         self.num_output_rows += batch.num_rows();
 
                         match self.join_type {
-                            JoinType::Left | JoinType::Full => {
+                            JoinType::Left | JoinType::Full | JoinType::Semi => {
                                 left_side.iter().flatten().for_each(|x| {
                                     self.visited_left_side[x as usize] = true;
                                 });
@@ -1185,12 +1204,15 @@ impl Stream for HashJoinStream {
                     let start = Instant::now();
                     // For the left join, produce rows for unmatched rows
                     match self.join_type {
-                        JoinType::Left | JoinType::Full if !self.is_exhausted => {
-                            let result = produce_unmatched(
+                        JoinType::Left | JoinType::Full | JoinType::Semi
+                            if !self.is_exhausted =>
+                        {
+                            let result = produce_from_matched(
                                 &self.visited_left_side,
                                 &self.schema,
                                 &self.column_indices,
                                 &self.left_data,
+                                self.join_type != JoinType::Semi,
                             );
                             if let Ok(ref batch) = result {
                                 self.num_input_batches += 1;
@@ -1207,6 +1229,7 @@ impl Stream for HashJoinStream {
                         }
                         JoinType::Left
                         | JoinType::Full
+                        | JoinType::Semi
                         | JoinType::Inner
                         | JoinType::Right => {}
                     }
@@ -1667,6 +1690,42 @@ mod tests {
     }
 
     #[tokio::test]
+    async fn join_semi() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2, 3]),
+            ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right
+            ("c1", &vec![7, 8, 8, 9]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30, 40]),
+            ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right
+            ("c2", &vec![70, 80, 90, 100]),
+        );
+        let on = &[("b1", "b1")];
+
+        let join = join(left, right, on, &JoinType::Semi)?;
+
+        let columns = columns(&join.schema());
+        assert_eq!(columns, vec!["a1", "b1", "c1"]);
+
+        let stream = join.execute(0).await?;
+        let batches = common::collect(stream).await?;
+
+        let expected = vec![
+            "+----+----+----+",
+            "| a1 | b1 | c1 |",
+            "+----+----+----+",
+            "| 1  | 4  | 7  |",
+            "| 2  | 5  | 8  |",
+            "| 2  | 5  | 8  |",
+            "+----+----+----+",
+        ];
+        assert_batches_sorted_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    #[tokio::test]
     async fn join_right_one() -> Result<()> {
         let left = build_table(
             ("a1", &vec![1, 2, 3]),
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 7e030af..110319e 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -32,6 +32,8 @@ pub enum JoinType {
     Right,
     /// Full Join
     Full,
+    /// Semi Join
+    Semi,
 }
 
 /// The on clause of the join, as vector of (left, right) columns.
@@ -130,6 +132,7 @@ pub fn build_join_schema(
             // left then right
             left_fields.chain(right_fields).cloned().collect()
         }
+        JoinType::Semi => left.fields().clone(),
     };
     Schema::new(fields)
 }
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 7ddfaf8..4971a02 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -367,6 +367,7 @@ impl DefaultPhysicalPlanner {
                     JoinType::Left => hash_utils::JoinType::Left,
                     JoinType::Right => hash_utils::JoinType::Right,
                     JoinType::Full => hash_utils::JoinType::Full,
+                    JoinType::Semi => hash_utils::JoinType::Semi,
                 };
                 if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins
                 {