You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2023/06/21 11:13:42 UTC

[arrow-datafusion] branch main updated: Hash Join Vectorized collision checking (#6724)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 98669b000e Hash Join Vectorized collision checking (#6724)
98669b000e is described below

commit 98669b000ef5b3ca04e46924760d8925155d6c7d
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Wed Jun 21 13:13:35 2023 +0200

    Hash Join Vectorized collision checking (#6724)
    
    * Initial PoC
    
    * Initial PoC
    
    * Initial PoC
    
    * Vectorized implementation
    
    * Vectorized implementation
    
    * Vectorized implementation
    
    * Vectorized implementation
    
    * Fmt
    
    * Add implementation for null equals null
    
    * WIP
    
    * Cleanup
    
    * Fix
    
    * Rev
    
    * Add feature
    
    * Small tweak
    
    * Add back comment
    
    * More functional implementation
    
    * Remove some ownership / clones
    
    * Clippy
    
    ---------
    
    Co-authored-by: Daniël Heres <da...@coralogix.com>
---
 Cargo.toml                                         |   2 +-
 .../core/src/physical_plan/joins/hash_join.rs      | 169 +++++++++++++--------
 .../src/physical_plan/joins/nested_loop_join.rs    |   8 +-
 .../src/physical_plan/joins/symmetric_hash_join.rs |   8 +-
 datafusion/core/src/physical_plan/joins/utils.rs   |  12 +-
 5 files changed, 121 insertions(+), 78 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index ed6728f35e..7f89e78ad0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -45,7 +45,7 @@ repository = "https://github.com/apache/arrow-datafusion"
 rust-version = "1.64"
 
 [workspace.dependencies]
-arrow = { version = "42.0.0", features = ["prettyprint"] }
+arrow = { version = "42.0.0", features = ["prettyprint", "dyn_cmp_dict"] }
 arrow-flight = { version = "42.0.0", features = ["flight-sql-experimental"] }
 arrow-buffer = { version = "42.0.0", default-features = false }
 arrow-schema = { version = "42.0.0", default-features = false }
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 5a0e8f33f5..c89edea4d6 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -25,6 +25,8 @@ use arrow::array::{
     StringArray, TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array,
     UInt8Array,
 };
+use arrow::buffer::BooleanBuffer;
+use arrow::compute::{and, eq_dyn, is_null, or_kleene, take, FilterBuilder};
 use arrow::datatypes::{ArrowNativeType, DataType};
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -42,6 +44,8 @@ use arrow::{
     },
     util::bit_util,
 };
+use arrow_array::cast::downcast_array;
+use arrow_schema::ArrowError;
 use futures::{ready, Stream, StreamExt, TryStreamExt};
 use std::fmt;
 use std::mem::size_of;
@@ -624,47 +628,6 @@ impl RecordBatchStream for HashJoinStream {
     }
 }
 
-/// Gets build and probe indices which satisfy the on condition (including
-/// the equality condition and the join filter) in the join.
-#[allow(clippy::too_many_arguments)]
-pub fn build_join_indices(
-    probe_batch: &RecordBatch,
-    build_hashmap: &JoinHashMap,
-    build_input_buffer: &RecordBatch,
-    on_build: &[Column],
-    on_probe: &[Column],
-    filter: Option<&JoinFilter>,
-    random_state: &RandomState,
-    null_equals_null: bool,
-    hashes_buffer: &mut Vec<u64>,
-    build_side: JoinSide,
-) -> Result<(UInt64Array, UInt32Array)> {
-    // Get the indices that satisfy the equality condition, like `left.a1 = right.a2`
-    let (build_indices, probe_indices) = build_equal_condition_join_indices(
-        build_hashmap,
-        build_input_buffer,
-        probe_batch,
-        on_build,
-        on_probe,
-        random_state,
-        null_equals_null,
-        hashes_buffer,
-    )?;
-    if let Some(filter) = filter {
-        // Filter the indices which satisfy the non-equal join condition, like `left.b1 = 10`
-        apply_join_filter_to_indices(
-            build_input_buffer,
-            probe_batch,
-            build_indices,
-            probe_indices,
-            filter,
-            build_side,
-        )
-    } else {
-        Ok((build_indices, probe_indices))
-    }
-}
-
 // Returns build/probe indices satisfying the equality condition.
 // On LEFT.b1 = RIGHT.b2
 // LEFT Table:
@@ -706,6 +669,8 @@ pub fn build_equal_condition_join_indices(
     random_state: &RandomState,
     null_equals_null: bool,
     hashes_buffer: &mut Vec<u64>,
+    filter: Option<&JoinFilter>,
+    build_side: JoinSide,
 ) -> Result<(UInt64Array, UInt32Array)> {
     let keys_values = probe_on
         .iter()
@@ -737,17 +702,8 @@ pub fn build_equal_condition_join_indices(
         {
             let mut i = *index - 1;
             loop {
-                // Check hash collisions
-                if equal_rows(
-                    i as usize,
-                    row,
-                    &build_join_values,
-                    &keys_values,
-                    null_equals_null,
-                )? {
-                    build_indices.append(i);
-                    probe_indices.append(row as u32);
-                }
+                build_indices.append(i);
+                probe_indices.append(row as u32);
                 // Follow the chain to get the next index value
                 let next = build_hashmap.next[i as usize];
                 if next == 0 {
@@ -759,10 +715,30 @@ pub fn build_equal_condition_join_indices(
         }
     }
 
-    Ok((
-        PrimitiveArray::new(build_indices.finish().into(), None),
-        PrimitiveArray::new(probe_indices.finish().into(), None),
-    ))
+    let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None);
+    let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None);
+
+    let (left, right) = if let Some(filter) = filter {
+        // Keep only the index pairs that also satisfy the non-equi join filter, like `left.b1 = 10`
+        apply_join_filter_to_indices(
+            build_input_buffer,
+            probe_batch,
+            left,
+            right,
+            filter,
+            build_side,
+        )?
+    } else {
+        (left, right)
+    };
+
+    equal_rows_arr(
+        &left,
+        &right,
+        &build_join_values,
+        &keys_values,
+        null_equals_null,
+    )
 }
 
 macro_rules! equal_rows_elem {
@@ -1097,6 +1073,71 @@ pub fn equal_rows(
     err.unwrap_or(Ok(res))
 }
 
+// Version of `eq_dyn` that also handles `Null`-typed arrays and, if `null_equals_null`, treats null == null as equal
+fn eq_dyn_null(
+    left: &dyn Array,
+    right: &dyn Array,
+    null_equals_null: bool,
+) -> Result<BooleanArray, ArrowError> {
+    match (left.data_type(), right.data_type()) {
+        (DataType::Null, DataType::Null) => Ok(BooleanArray::new(
+            BooleanBuffer::collect_bool(left.len(), |_| null_equals_null),
+            None,
+        )),
+        _ if null_equals_null => {
+            let eq: BooleanArray = eq_dyn(left, right)?;
+
+            let left_is_null = is_null(left)?;
+            let right_is_null = is_null(right)?;
+
+            or_kleene(&and(&left_is_null, &right_is_null)?, &eq)
+        }
+        _ => eq_dyn(left, right),
+    }
+}
+
+pub fn equal_rows_arr(
+    indices_left: &UInt64Array,
+    indices_right: &UInt32Array,
+    left_arrays: &[ArrayRef],
+    right_arrays: &[ArrayRef],
+    null_equals_null: bool,
+) -> Result<(UInt64Array, UInt32Array)> {
+    let mut iter = left_arrays.iter().zip(right_arrays.iter());
+
+    let (first_left, first_right) = iter.next().ok_or_else(|| {
+        DataFusionError::Internal(
+            "At least one array should be provided for both left and right".to_string(),
+        )
+    })?;
+
+    let arr_left = take(first_left.as_ref(), indices_left, None)?;
+    let arr_right = take(first_right.as_ref(), indices_right, None)?;
+
+    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equals_null)?;
+
+    // Use map and try_fold to iterate over the remaining pairs of arrays.
+    // In each iteration, take is used on the pair of arrays and their equality is determined.
+    // The results are then folded (combined) using the and function to get a final equality result.
+    equal = iter
+        .map(|(left, right)| {
+            let arr_left = take(left.as_ref(), indices_left, None)?;
+            let arr_right = take(right.as_ref(), indices_right, None)?;
+            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equals_null)
+        })
+        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
+
+    let filter_builder = FilterBuilder::new(&equal).optimize().build();
+
+    let left_filtered = filter_builder.filter(indices_left)?;
+    let right_filtered = filter_builder.filter(indices_right)?;
+
+    Ok((
+        downcast_array(left_filtered.as_ref()),
+        downcast_array(right_filtered.as_ref()),
+    ))
+}
+
 impl HashJoinStream {
     /// Separate implementation function that unpins the [`HashJoinStream`] so
     /// that partial borrows work correctly
@@ -1149,16 +1190,16 @@ impl HashJoinStream {
                     let timer = self.join_metrics.join_time.timer();
 
                     // get the matched two indices for the on condition
-                    let left_right_indices = build_join_indices(
-                        &batch,
+                    let left_right_indices = build_equal_condition_join_indices(
                         &left_data.0,
                         &left_data.1,
+                        &batch,
                         &self.on_left,
                         &self.on_right,
-                        self.filter.as_ref(),
                         &self.random_state,
                         self.null_equals_null,
                         &mut hashes_buffer,
+                        self.filter.as_ref(),
                         JoinSide::Left,
                     );
 
@@ -1184,8 +1225,8 @@ impl HashJoinStream {
                                 &self.schema,
                                 &left_data.1,
                                 &batch,
-                                left_side,
-                                right_side,
+                                &left_side,
+                                &right_side,
                                 &self.column_indices,
                                 JoinSide::Left,
                             );
@@ -1216,8 +1257,8 @@ impl HashJoinStream {
                             &self.schema,
                             &left_data.1,
                             &empty_right_batch,
-                            left_side,
-                            right_side,
+                            &left_side,
+                            &right_side,
                             &self.column_indices,
                             JoinSide::Left,
                         );
@@ -2644,6 +2685,8 @@ mod tests {
             &random_state,
             false,
             &mut vec![0; right.num_rows()],
+            None,
+            JoinSide::Left,
         )?;
 
         let mut left_ids = UInt64Builder::with_capacity(0);
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index 8de5c76e51..bb8c190222 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -482,8 +482,8 @@ impl NestedLoopJoinStream {
                             &self.schema,
                             left_data,
                             &empty_right_batch,
-                            left_side,
-                            right_side,
+                            &left_side,
+                            &right_side,
                             &self.column_indices,
                             JoinSide::Left,
                         );
@@ -611,8 +611,8 @@ fn join_left_and_right_batch(
                 schema,
                 left_batch,
                 right_batch,
-                left_side,
-                right_side,
+                &left_side,
+                &right_side,
                 column_indices,
                 JoinSide::Left,
             )
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 68a8596d66..490f18ccf4 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -1369,8 +1369,8 @@ impl OneSideHashJoiner {
                 schema,
                 &self.input_buffer,
                 probe_batch,
-                build_indices,
-                probe_indices,
+                &build_indices,
+                &probe_indices,
                 column_indices,
                 self.build_side,
             )
@@ -1421,8 +1421,8 @@ impl OneSideHashJoiner {
                 output_schema.as_ref(),
                 &self.input_buffer,
                 &empty_probe_batch,
-                build_indices,
-                probe_indices,
+                &build_indices,
+                &probe_indices,
                 column_indices,
                 self.build_side,
             )
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index f7e81b5add..627bdeebc5 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -784,8 +784,8 @@ pub(crate) fn apply_join_filter_to_indices(
         filter.schema(),
         build_input_buffer,
         probe_batch,
-        build_indices.clone(),
-        probe_indices.clone(),
+        &build_indices,
+        &probe_indices,
         filter.column_indices(),
         build_side,
     )?;
@@ -809,8 +809,8 @@ pub(crate) fn build_batch_from_indices(
     schema: &Schema,
     build_input_buffer: &RecordBatch,
     probe_batch: &RecordBatch,
-    build_indices: UInt64Array,
-    probe_indices: UInt32Array,
+    build_indices: &UInt64Array,
+    probe_indices: &UInt32Array,
     column_indices: &[ColumnIndex],
     build_side: JoinSide,
 ) -> Result<RecordBatch> {
@@ -841,7 +841,7 @@ pub(crate) fn build_batch_from_indices(
                 assert_eq!(build_indices.null_count(), build_indices.len());
                 new_null_array(array.data_type(), build_indices.len())
             } else {
-                compute::take(array.as_ref(), &build_indices, None)?
+                compute::take(array.as_ref(), build_indices, None)?
             }
         } else {
             let array = probe_batch.column(column_index.index);
@@ -849,7 +849,7 @@ pub(crate) fn build_batch_from_indices(
                 assert_eq!(probe_indices.null_count(), probe_indices.len());
                 new_null_array(array.data_type(), probe_indices.len())
             } else {
-                compute::take(array.as_ref(), &probe_indices, None)?
+                compute::take(array.as_ref(), probe_indices, None)?
             }
         };
         columns.push(array);