You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/20 18:27:58 UTC
[arrow-datafusion] branch master updated: feat: support nested loop join with the initial version (#4562)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new fddb3d365 feat: support nested loop join with the initial version (#4562)
fddb3d365 is described below
commit fddb3d3651041f41d66a801f10e27387e84374f7
Author: Kun Liu <li...@apache.org>
AuthorDate: Wed Dec 21 02:27:51 2022 +0800
feat: support nested loop join with the initial version (#4562)
* support on condition without equal condition
* refine code for hash join exec
* add method check distribution of right, make code clearly
* fix failed case
* add comments for pub(crate) method
---
.../core/src/physical_plan/joins/hash_join.rs | 232 +-----
datafusion/core/src/physical_plan/joins/mod.rs | 2 +
.../src/physical_plan/joins/nested_loop_join.rs | 893 +++++++++++++++++++++
datafusion/core/src/physical_plan/joins/utils.rs | 247 +++++-
datafusion/core/src/physical_plan/planner.rs | 13 +-
datafusion/core/tests/sql/joins.rs | 4 +-
datafusion/sql/src/planner.rs | 111 ++-
7 files changed, 1239 insertions(+), 263 deletions(-)
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs
index f72e539d7..f15398018 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -26,9 +26,8 @@ use arrow::{
DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray,
Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray,
- UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder,
+ UInt32BufferBuilder, UInt64BufferBuilder,
},
- compute,
datatypes::{
Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
@@ -41,7 +40,7 @@ use std::{time::Instant, vec};
use futures::{ready, Stream, StreamExt, TryStreamExt};
-use arrow::array::{new_null_array, Array};
+use arrow::array::Array;
use arrow::datatypes::{ArrowNativeType, DataType};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::{ArrowError, Result as ArrowResult};
@@ -53,7 +52,7 @@ use arrow::array::{
UInt8Array,
};
-use datafusion_common::cast::{as_boolean_array, as_dictionary_array, as_string_array};
+use datafusion_common::cast::{as_dictionary_array, as_string_array};
use hashbrown::raw::RawTable;
@@ -66,7 +65,7 @@ use crate::physical_plan::{
joins::utils::{
adjust_right_output_partitioning, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties, estimate_join_statistics,
- partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn, JoinSide,
+ partitioned_join_output_partitioning, ColumnIndex, JoinFilter, JoinOn,
},
metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning,
@@ -84,6 +83,10 @@ use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
};
+use crate::physical_plan::joins::utils::{
+ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices,
+ get_final_indices_from_bit_map, need_produce_result_in_final,
+};
use log::debug;
use std::fmt;
use std::task::Poll;
@@ -647,50 +650,6 @@ impl RecordBatchStream for HashJoinStream {
}
}
-/// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`.
-/// The resulting batch has [Schema] `schema`.
-fn build_batch_from_indices(
- schema: &Schema,
- left: &RecordBatch,
- right: &RecordBatch,
- left_indices: UInt64Array,
- right_indices: UInt32Array,
- column_indices: &[ColumnIndex],
-) -> ArrowResult<RecordBatch> {
- // build the columns of the new [RecordBatch]:
- // 1. pick whether the column is from the left or right
- // 2. based on the pick, `take` items from the different RecordBatches
- let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
-
- for column_index in column_indices {
- let array = match column_index.side {
- JoinSide::Left => {
- let array = left.column(column_index.index);
- if array.is_empty() || left_indices.null_count() == left_indices.len() {
- // Outer join would generate a null index when finding no match at our side.
- // Therefore, it's possible we are empty but need to populate an n-length null array,
- // where n is the length of the index array.
- assert_eq!(left_indices.null_count(), left_indices.len());
- new_null_array(array.data_type(), left_indices.len())
- } else {
- compute::take(array.as_ref(), &left_indices, None)?
- }
- }
- JoinSide::Right => {
- let array = right.column(column_index.index);
- if array.is_empty() || right_indices.null_count() == right_indices.len() {
- assert_eq!(right_indices.null_count(), right_indices.len());
- new_null_array(array.data_type(), right_indices.len())
- } else {
- compute::take(array.as_ref(), &right_indices, None)?
- }
- }
- };
- columns.push(array);
- }
- RecordBatch::try_new(Arc::new(schema.clone()), columns)
-}
-
// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join
#[allow(clippy::too_many_arguments)]
fn build_join_indices(
@@ -821,41 +780,6 @@ fn build_equal_condition_join_indices(
))
}
-fn apply_join_filter_to_indices(
- left: &RecordBatch,
- right: &RecordBatch,
- left_indices: UInt64Array,
- right_indices: UInt32Array,
- filter: &JoinFilter,
-) -> Result<(UInt64Array, UInt32Array)> {
- if left_indices.is_empty() && right_indices.is_empty() {
- return Ok((left_indices, right_indices));
- };
-
- let intermediate_batch = build_batch_from_indices(
- filter.schema(),
- left,
- right,
- PrimitiveArray::from(left_indices.data().clone()),
- PrimitiveArray::from(right_indices.data().clone()),
- filter.column_indices(),
- )?;
- let filter_result = filter
- .expression()
- .evaluate(&intermediate_batch)?
- .into_array(intermediate_batch.num_rows());
- let mask = as_boolean_array(&filter_result)?;
-
- let left_filtered = PrimitiveArray::<UInt64Type>::from(
- compute::filter(&left_indices, mask)?.data().clone(),
- );
- let right_filtered = PrimitiveArray::<UInt32Type>::from(
- compute::filter(&right_indices, mask)?.data().clone(),
- );
-
- Ok((left_filtered, right_filtered))
-}
-
macro_rules! equal_rows_elem {
($array_type:ident, $l: ident, $r: ident, $left: ident, $right: ident, $null_equals_null: ident) => {{
let left_array = $l.as_any().downcast_ref::<$array_type>().unwrap();
@@ -1186,138 +1110,6 @@ fn equal_rows(
err.unwrap_or(Ok(res))
}
-// The input is the matched indices for left and right.
-// Adjust the indices according to the join type
-fn adjust_indices_by_join_type(
- left_indices: UInt64Array,
- right_indices: UInt32Array,
- count_right_batch: usize,
- join_type: JoinType,
-) -> (UInt64Array, UInt32Array) {
- match join_type {
- JoinType::Inner => {
- // matched
- (left_indices, right_indices)
- }
- JoinType::Left => {
- // matched
- (left_indices, right_indices)
- // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap
- }
- JoinType::Right | JoinType::Full => {
- // matched
- // unmatched right row will be produced in this batch
- let right_unmatched_indices =
- get_anti_indices(count_right_batch, &right_indices);
- // combine the matched and unmatched right result together
- append_right_indices(left_indices, right_indices, right_unmatched_indices)
- }
- JoinType::RightSemi => {
- // need to remove the duplicated record in the right side
- let right_indices = get_semi_indices(count_right_batch, &right_indices);
- // the left_indices will not be used later for the `right semi` join
- (left_indices, right_indices)
- }
- JoinType::RightAnti => {
- // need to remove the duplicated record in the right side
- // get the anti index for the right side
- let right_indices = get_anti_indices(count_right_batch, &right_indices);
- // the left_indices will not be used later for the `right anti` join
- (left_indices, right_indices)
- }
- JoinType::LeftSemi | JoinType::LeftAnti => {
- // matched or unmatched left row will be produced in the end of loop
- // TODO: left semi can be optimized.
- // When visit the right batch, we can output the matched left row and don't need to wait the end of loop
- (
- UInt64Array::from_iter_values(vec![]),
- UInt32Array::from_iter_values(vec![]),
- )
- }
- }
-}
-
-fn append_right_indices(
- left_indices: UInt64Array,
- right_indices: UInt32Array,
- appended_right_indices: UInt32Array,
-) -> (UInt64Array, UInt32Array) {
- // left_indices, right_indices and appended_right_indices must not contain the null value
- if appended_right_indices.is_empty() {
- (left_indices, right_indices)
- } else {
- let unmatched_size = appended_right_indices.len();
- // the new left indices: left_indices + null array
- // the new right indices: right_indices + appended_right_indices
- let new_left_indices = left_indices
- .iter()
- .chain(std::iter::repeat(None).take(unmatched_size))
- .collect::<UInt64Array>();
- let new_right_indices = right_indices
- .iter()
- .chain(appended_right_indices.iter())
- .collect::<UInt32Array>();
- (new_left_indices, new_right_indices)
- }
-}
-
-fn get_anti_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
- let mut bitmap = BooleanBufferBuilder::new(row_count);
- bitmap.append_n(row_count, false);
- input_indices.iter().flatten().for_each(|v| {
- bitmap.set_bit(v as usize, true);
- });
-
- // get the anti index
- (0..row_count)
- .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32))
- .collect::<UInt32Array>()
-}
-
-fn get_semi_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array {
- let mut bitmap = BooleanBufferBuilder::new(row_count);
- bitmap.append_n(row_count, false);
- input_indices.iter().flatten().for_each(|v| {
- bitmap.set_bit(v as usize, true);
- });
-
- // get the semi index
- (0..row_count)
- .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32))
- .collect::<UInt32Array>()
-}
-
-fn need_produce_result_in_final(join_type: JoinType) -> bool {
- matches!(
- join_type,
- JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full
- )
-}
-
-fn get_final_indices(
- left_bit_map: &BooleanBufferBuilder,
- join_type: JoinType,
-) -> (UInt64Array, UInt32Array) {
- let left_size = left_bit_map.len();
- let left_indices = if join_type == JoinType::LeftSemi {
- (0..left_size)
- .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
- .collect::<UInt64Array>()
- } else {
- // just for `Left`, `LeftAnti` and `Full` join
- // `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally
- (0..left_size)
- .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
- .collect::<UInt64Array>()
- };
- // right_indices
- // all the element in the right side is None
- let mut builder = UInt32Builder::with_capacity(left_indices.len());
- builder.append_nulls(left_indices.len());
- let right_indices = builder.finish();
- (left_indices, right_indices)
-}
-
impl HashJoinStream {
/// Separate implementation function that unpins the [`HashJoinStream`] so
/// that partial borrows work correctly
@@ -1413,8 +1205,10 @@ impl HashJoinStream {
if need_produce_result_in_final(self.join_type) && !self.is_exhausted
{
// use the global left bitmap to produce the left indices and right indices
- let (left_side, right_side) =
- get_final_indices(visited_left_side, self.join_type);
+ let (left_side, right_side) = get_final_indices_from_bit_map(
+ visited_left_side,
+ self.join_type,
+ );
let empty_right_batch =
RecordBatch::new_empty(self.right.schema());
// use the left and right indices to produce the batch result
@@ -1469,12 +1263,14 @@ mod tests {
test::exec::MockExec,
test::{build_table_i32, columns},
};
+ use arrow::array::UInt32Builder;
use arrow::array::UInt64Builder;
use arrow::datatypes::Field;
use arrow::error::ArrowError;
use datafusion_expr::Operator;
use super::*;
+ use crate::physical_plan::joins::utils::JoinSide;
use crate::prelude::SessionContext;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;
diff --git a/datafusion/core/src/physical_plan/joins/mod.rs b/datafusion/core/src/physical_plan/joins/mod.rs
index 8066c7d9c..63762ab3c 100644
--- a/datafusion/core/src/physical_plan/joins/mod.rs
+++ b/datafusion/core/src/physical_plan/joins/mod.rs
@@ -19,6 +19,7 @@
mod cross_join;
mod hash_join;
+mod nested_loop_join;
mod sort_merge_join;
pub mod utils;
@@ -36,6 +37,7 @@ pub enum PartitionMode {
pub use cross_join::CrossJoinExec;
pub use hash_join::HashJoinExec;
+pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
pub use sort_merge_join::SortMergeJoinExec;
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
new file mode 100644
index 000000000..834392512
--- /dev/null
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -0,0 +1,893 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the nested loop join plan, which supports every [`JoinType`].
+//! Whether the nested loop join executes in parallel across partitions, and which
+//! side stays partitioned, is determined by the [`JoinType`].
+
+use crate::physical_plan::joins::utils::{
+ adjust_indices_by_join_type, adjust_right_output_partitioning,
+ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
+ check_join_is_valid, combine_join_equivalence_properties, estimate_join_statistics,
+ get_final_indices_from_bit_map, need_produce_result_in_final, ColumnIndex,
+ JoinFilter, OnceAsync, OnceFut,
+};
+use crate::physical_plan::{
+ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream,
+ SendableRecordBatchStream,
+};
+use arrow::array::{
+ BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder,
+};
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::record_batch::RecordBatch;
+use datafusion_common::Statistics;
+use datafusion_expr::JoinType;
+use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr};
+use futures::{ready, Stream, StreamExt, TryStreamExt};
+use log::debug;
+use std::any::Any;
+use std::fmt::Formatter;
+use std::sync::Arc;
+use std::task::Poll;
+use std::time::Instant;
+
+use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::physical_plan::coalesce_batches::concat_batches;
+
+/// Data of the left side: all left batches of one partition concatenated into a single [`RecordBatch`]
+type JoinLeftData = RecordBatch;
+
+/// Execution plan for a nested loop join of `left` and `right`, with an optional join `filter`.
+#[derive(Debug)]
+pub struct NestedLoopJoinExec {
+ /// Left input plan; collected into one in-memory batch during execution
+ pub(crate) left: Arc<dyn ExecutionPlan>,
+ /// Right input plan; consumed batch by batch during execution
+ pub(crate) right: Arc<dyn ExecutionPlan>,
+ /// Optional filter evaluated on candidate (left row, right row) pairs; `None` keeps every pair
+ pub(crate) filter: Option<JoinFilter>,
+ /// How the join is performed
+ pub(crate) join_type: JoinType,
+ /// The schema once the join is applied (produced by `build_join_schema`)
+ schema: SchemaRef,
+ /// Left-side data future obtained through `OnceAsync::once`; only used when left must be a single partition
+ left_fut: OnceAsync<JoinLeftData>,
+ /// For each output column, the side (left / right) and the column index it is taken from
+ column_indices: Vec<ColumnIndex>,
+}
+
+impl NestedLoopJoinExec {
+ /// Try to create a new [`NestedLoopJoinExec`]; fails if `check_join_is_valid` rejects the two input schemas
+ pub fn try_new(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ filter: Option<JoinFilter>,
+ join_type: &JoinType,
+ ) -> Result<Self> {
+ let left_schema = left.schema();
+ let right_schema = right.schema();
+ check_join_is_valid(&left_schema, &right_schema, &[])?; // empty key list: this join has no equi-join keys
+ let (schema, column_indices) =
+ build_join_schema(&left_schema, &right_schema, join_type);
+ Ok(NestedLoopJoinExec {
+ left,
+ right,
+ filter,
+ join_type: *join_type,
+ schema: Arc::new(schema),
+ left_fut: Default::default(),
+ column_indices,
+ })
+ }
+
+ fn is_single_partition_for_left(&self) -> bool {
+ matches!(
+ self.required_input_distribution()[0], // entry 0 describes the left child
+ Distribution::SinglePartition
+ )
+ }
+
+ fn is_single_partition_for_right(&self) -> bool {
+ matches!(
+ self.required_input_distribution()[1], // entry 1 describes the right child
+ Distribution::SinglePartition
+ )
+ }
+}
+
+impl ExecutionPlan for NestedLoopJoinExec {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ fn output_partitioning(&self) -> Partitioning {
+ // the output follows the side that `required_input_distribution` does not force into a single partition (left for `Full`)
+ // TODO: this could be replaced by `partitioned_join_output_partitioning`
+ match self.join_type {
+ // output keeps the left input's partitioning
+ JoinType::Inner
+ | JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::Full => self.left.output_partitioning(),
+ // output keeps the right input's partitioning, with adjusted column indices
+ JoinType::Right => {
+ // if the right side is hash partitioned, the column indices of its
+ // partitioning expressions must be adjusted by the number of left columns
+ adjust_right_output_partitioning(
+ self.right.output_partitioning(),
+ self.left.schema().fields.len(),
+ )
+ }
+ // output keeps the right input's partitioning unchanged
+ JoinType::RightSemi | JoinType::RightAnti => self.right.output_partitioning(),
+ }
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ // this plan reports no output ordering
+ None
+ }
+
+ fn required_input_distribution(&self) -> Vec<Distribution> {
+ distribution_from_join_type(&self.join_type)
+ }
+
+ fn equivalence_properties(&self) -> EquivalenceProperties {
+ let left_columns_len = self.left.schema().fields.len();
+ combine_join_equivalence_properties(
+ self.join_type,
+ self.left.equivalence_properties(),
+ self.right.equivalence_properties(),
+ left_columns_len,
+ &[], // empty join keys: a nested loop join has no equi-join `on` columns
+ self.schema(),
+ )
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ vec![self.left.clone(), self.right.clone()]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ children: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ Ok(Arc::new(NestedLoopJoinExec::try_new(
+ children[0].clone(),
+ children[1].clone(),
+ self.filter.clone(),
+ &self.join_type,
+ )?))
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ // left side: collected in full before any right batch is joined
+ let left_fut = if self.is_single_partition_for_left() {
+ // left is required to be `SinglePartition`, so every output partition joins against the same left data
+ self.left_fut.once(|| {
+ // partition 0 is the only left partition, so it holds all of the left rows
+ load_left_specified_partition(0, self.left.clone(), context.clone())
+ })
+ } else {
+ // left is not required to be a single partition: load only the left partition with the requested index
+ OnceFut::new(load_left_specified_partition(
+ partition,
+ self.left.clone(),
+ context.clone(),
+ ))
+ };
+ // right side: consumed as a stream of batches
+ let right_side = if self.is_single_partition_for_right() {
+ // right is required to be `SinglePartition`,
+ // so partition 0 holds all of the right rows
+ self.right.execute(0, context)?
+ } else {
+ // right is not required to be a single partition: stream only the right partition with the requested index
+ self.right.execute(partition, context)?
+ };
+
+ Ok(Box::pin(NestedLoopJoinStream {
+ schema: self.schema.clone(),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ left_fut,
+ right: right_side,
+ is_exhausted: false,
+ visited_left_side: None,
+ column_indices: self.column_indices.clone(),
+ }))
+ }
+
+ fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ let display_filter = self.filter.as_ref().map_or_else(
+ || "".to_string(),
+ |f| format!(", filter={:?}", f.expression()),
+ );
+ write!(
+ f,
+ "NestedLoopJoinExec: join_type={:?}{}",
+ self.join_type, display_filter
+ )
+ }
+ }
+ }
+
+ fn statistics(&self) -> Statistics {
+ estimate_join_statistics(
+ self.left.clone(),
+ self.right.clone(),
+ vec![],
+ &self.join_type,
+ )
+ }
+}
+
+// Returns the required input distribution as `[left, right]` for the nested loop join;
+// which side must be a single partition depends on the join type.
+fn distribution_from_join_type(join_type: &JoinType) -> Vec<Distribution> {
+ match join_type {
+ JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => {
+ // the left side may stay partitioned; the right side must be a single partition
+ vec![
+ Distribution::UnspecifiedDistribution,
+ Distribution::SinglePartition,
+ ]
+ }
+ JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
+ // the right side may stay partitioned; the left side must be a single partition
+ vec![
+ Distribution::SinglePartition,
+ Distribution::UnspecifiedDistribution,
+ ]
+ }
+ JoinType::Full => {
+ // both the left and the right side must be a single partition
+ vec![Distribution::SinglePartition, Distribution::SinglePartition]
+ }
+ }
+}
+
+/// Asynchronously executes `partition` of the `left` child and concatenates all of its batches into one batch
+async fn load_left_specified_partition(
+ partition: usize,
+ left: Arc<dyn ExecutionPlan>,
+ context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+ let start = Instant::now();
+ let stream = left.execute(partition, context)?;
+
+ // Drain the stream, keeping every batch and a running total of the row count
+ let (batches, num_rows) = stream
+ .try_fold((Vec::new(), 0usize), |mut acc, batch| async {
+ acc.1 += batch.num_rows();
+ acc.0.push(batch);
+ Ok(acc)
+ })
+ .await?;
+
+ let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?; // one batch, so a left row is addressed by a single index
+
+ debug!(
+ "Built left-side of nested loop join containing {} rows in {} ms for partition {}",
+ num_rows,
+ start.elapsed().as_millis(),
+ partition
+ );
+
+ Ok(merged_batch)
+}
+
+/// A stream that emits one joined [RecordBatch] per batch arriving from the right of the join, plus a final left-side batch when the join type needs it.
+struct NestedLoopJoinStream {
+ /// Output schema of the joined batches
+ schema: Arc<Schema>,
+ /// Optional join filter applied to candidate row pairs
+ filter: Option<JoinFilter>,
+ /// Type of the join
+ join_type: JoinType,
+ /// Future resolving to the collected left-side data
+ left_fut: OnceFut<JoinLeftData>,
+ /// Stream of right-side batches
+ right: SendableRecordBatchStream,
+ /// Set once the final left-side batch has been produced after the right stream ended (left / left semi / left anti / full join)
+ is_exhausted: bool,
+ /// One bit per left row, set when the row has matched; created on first poll, and empty for join types without a final batch
+ visited_left_side: Option<BooleanBufferBuilder>,
+ /// For each output column, the side (left / right) and the column index it is taken from
+ column_indices: Vec<ColumnIndex>,
+ // TODO: support null aware equal
+ // null_equals_null: bool
+}
+
+fn build_join_indices(
+ left_index: usize,
+ batch: &RecordBatch,
+ left_data: &JoinLeftData,
+ filter: Option<&JoinFilter>,
+) -> Result<(UInt64Array, UInt32Array)> { // (left indices, right indices) of the pairs kept for left row `left_index`
+ let right_row_count = batch.num_rows();
+ // left indices: [left_index, left_index, ..., left_index] (one entry per right row)
+ // right indices: [0, 1, 2, ..., right_row_count - 1]
+ let left_indices = UInt64Array::from(vec![left_index as u64; right_row_count]);
+ let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32));
+ // in the nested loop join, the filter can contain both non-equal and equal conditions; without a filter every pair is kept.
+ if let Some(filter) = filter {
+ apply_join_filter_to_indices(
+ left_data,
+ batch,
+ left_indices,
+ right_indices,
+ filter,
+ )
+ } else {
+ Ok((left_indices, right_indices))
+ }
+}
+
+impl NestedLoopJoinStream {
+ fn poll_next_impl(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<ArrowResult<RecordBatch>>> {
+ // wait until all left rows have been collected; a failure to collect them is returned as the stream item
+ let left_data = match ready!(self.left_fut.get(cx)) {
+ Ok(left_data) => left_data,
+ Err(e) => return Poll::Ready(Some(Err(e))),
+ };
+
+ let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+ let left_num_rows = left_data.num_rows();
+ if need_produce_result_in_final(self.join_type) {
+ // these join types need the bitmap to identify which left rows have been matched or unmatched.
+ // For the `left semi` join, the bitmap is used to produce the matched rows of the left side
+ // For the `left` join, the bitmap is used to produce the unmatched rows of the left side with nulls
+ // For the `left anti` join, the bitmap is used to produce the unmatched rows of the left side
+ // For the `full` join, the bitmap is used to produce the unmatched rows of the left side with nulls
+ let mut buffer = BooleanBufferBuilder::new(left_num_rows);
+ buffer.append_n(left_num_rows, false);
+ buffer
+ } else {
+ BooleanBufferBuilder::new(0)
+ }
+ });
+
+ // poll the next batch from the right side
+ self.right
+ .poll_next_unpin(cx)
+ .map(|maybe_batch| match maybe_batch {
+ Some(Ok(right_batch)) => {
+ // TODO: optimize this logic like the cross join, and just return a small batch for each loop
+ // collect the matched left and right indices:
+ // every left row is tested against every row of this right batch
+ let indices_result = (0..left_data.num_rows())
+ .map(|left_row_index| {
+ build_join_indices(
+ left_row_index,
+ &right_batch,
+ left_data,
+ self.filter.as_ref(),
+ )
+ })
+ .collect::<Result<Vec<(UInt64Array, UInt32Array)>>>();
+ let mut left_indices_builder = UInt64Builder::new();
+ let mut right_indices_builder = UInt32Builder::new();
+ let left_right_indices = match indices_result {
+ Err(_) => {
+ // TODO: the stream item is `Result<T, ArrowError>`, not `DataFusionError`; NOTE(review): the underlying error is dropped here
+ Err(ArrowError::ComputeError(
+ "Build left right indices error".to_string(),
+ ))
+ }
+ Ok(indices) => {
+ for (left_side, right_side) in indices {
+ left_indices_builder.append_values(
+ left_side.values(),
+ &vec![true; left_side.len()],
+ );
+ right_indices_builder.append_values(
+ right_side.values(),
+ &vec![true; right_side.len()],
+ );
+ }
+ Ok((
+ left_indices_builder.finish(),
+ right_indices_builder.finish(),
+ ))
+ }
+ };
+ let result = match left_right_indices {
+ Ok((left_side, right_side)) => {
+ // record the matched left rows in the left bitmap;
+ // only left, full, left semi and left anti joins need the left bitmap
+ if need_produce_result_in_final(self.join_type) {
+ left_side.iter().flatten().for_each(|x| {
+ visited_left_side.set_bit(x as usize, true);
+ });
+ }
+ // adjust the indices of both sides based on the join type
+ let (left_side, right_side) = adjust_indices_by_join_type(
+ left_side,
+ right_side,
+ right_batch.num_rows(),
+ self.join_type,
+ );
+
+ let result = build_batch_from_indices(
+ &self.schema,
+ left_data,
+ &right_batch,
+ left_side,
+ right_side,
+ &self.column_indices,
+ );
+ Some(result)
+ }
+ Err(e) => Some(Err(e)),
+ };
+ result
+ }
+ Some(err) => Some(err), // an error item from the right stream is passed through unchanged
+ None => {
+ if need_produce_result_in_final(self.join_type) && !self.is_exhausted
+ {
+ // use the left bitmap to produce the final left indices and right indices
+ let (left_side, right_side) = get_final_indices_from_bit_map(
+ visited_left_side,
+ self.join_type,
+ );
+ let empty_right_batch =
+ RecordBatch::new_empty(self.right.schema());
+ // build the final batch from those indices, using an empty right batch
+ let result = build_batch_from_indices(
+ &self.schema,
+ left_data,
+ &empty_right_batch,
+ left_side,
+ right_side,
+ &self.column_indices,
+ );
+ self.is_exhausted = true;
+ Some(result)
+ } else {
+ // the right stream has ended and no final batch is pending: end of the join loop
+ None
+ }
+ }
+ })
+ }
+}
+
+impl Stream for NestedLoopJoinStream {
+ type Item = ArrowResult<RecordBatch>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ self.poll_next_impl(cx) // all polling logic lives in the inherent `poll_next_impl`
+ }
+}
+
+impl RecordBatchStream for NestedLoopJoinStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone() // clones the `Arc<Schema>` handle held by the stream
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::physical_expr::expressions::BinaryExpr;
+ use crate::{
+ assert_batches_sorted_eq,
+ physical_plan::{
+ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec,
+ },
+ test::{build_table_i32, columns},
+ };
+ use arrow::datatypes::{DataType, Field};
+ use datafusion_expr::Operator;
+
+ use super::*;
+ use crate::physical_plan::joins::utils::JoinSide;
+ use crate::prelude::SessionContext;
+ use datafusion_common::ScalarValue;
+ use datafusion_physical_expr::expressions::Literal;
+ use datafusion_physical_expr::PhysicalExpr;
+ use std::sync::Arc;
+
+ fn build_table(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+ c: (&str, &Vec<i32>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let batch = build_table_i32(a, b, c);
+ let schema = batch.schema();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
+ fn build_left_table() -> Arc<dyn ExecutionPlan> {
+ build_table(
+ ("a1", &vec![5, 9, 11]),
+ ("b1", &vec![5, 8, 8]),
+ ("c1", &vec![50, 90, 110]),
+ )
+ }
+
+ fn build_right_table() -> Arc<dyn ExecutionPlan> {
+ build_table(
+ ("a2", &vec![12, 2, 10]),
+ ("b2", &vec![10, 2, 10]),
+ ("c2", &vec![40, 80, 100]),
+ )
+ }
+
+ fn prepare_join_filter() -> JoinFilter {
+ let column_indices = vec![
+ ColumnIndex {
+ index: 1,
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 1,
+ side: JoinSide::Right,
+ },
+ ];
+ let intermediate_schema = Schema::new(vec![
+ Field::new("x", DataType::Int32, true),
+ Field::new("x", DataType::Int32, true),
+ ]);
+ // left.b1!=8
+ let left_filter = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("x", 0)),
+ Operator::NotEq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(8)))),
+ )) as Arc<dyn PhysicalExpr>;
+ // right.b2!=10
+ let right_filter = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("x", 1)),
+ Operator::NotEq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
+ )) as Arc<dyn PhysicalExpr>;
+ // filter = left.b1!=8 and right.b2!=10
+ // after filter:
+ // left table:
+ // ("a1", &vec![5]),
+ // ("b1", &vec![5]),
+ // ("c1", &vec![50]),
+ // right table:
+ // ("a2", &vec![12, 2]),
+ // ("b2", &vec![10, 2]),
+ // ("c2", &vec![40, 80]),
+ let filter_expression =
+ Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter))
+ as Arc<dyn PhysicalExpr>;
+
+ JoinFilter::new(filter_expression, column_indices, intermediate_schema)
+ }
+
+ async fn multi_partitioned_join_collect( // test helper: builds a NestedLoopJoinExec and collects every executed output partition
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ join_type: &JoinType,
+ join_filter: Option<JoinFilter>,
+ context: Arc<TaskContext>,
+ ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+ let partition_count = 4; // partition count used for any side that is not required to be a single partition
+ let mut output_partition = 1; // raised to partition_count as soon as either input is repartitioned
+ let distribution = distribution_from_join_type(join_type);
+ // left: kept as-is when `SinglePartition` is required, otherwise round-robin repartitioned
+ let left = if matches!(distribution[0], Distribution::SinglePartition) {
+ left
+ } else {
+ output_partition = partition_count;
+ Arc::new(RepartitionExec::try_new(
+ left,
+ Partitioning::RoundRobinBatch(partition_count),
+ )?)
+ } as Arc<dyn ExecutionPlan>;
+
+ let right = if matches!(distribution[1], Distribution::SinglePartition) { // right: same rule, driven by distribution[1]
+ right
+ } else {
+ output_partition = partition_count;
+ Arc::new(RepartitionExec::try_new(
+ right,
+ Partitioning::RoundRobinBatch(partition_count),
+ )?)
+ } as Arc<dyn ExecutionPlan>;
+
+ // Build the join over inputs that follow its required distribution, so partitioned data is exercised
+ let nested_loop_join =
+ NestedLoopJoinExec::try_new(left, right, join_filter, join_type)?;
+ let columns = columns(&nested_loop_join.schema());
+ let mut batches = vec![];
+ for i in 0..output_partition { // execute and drain partitions 0..output_partition one after another
+ let stream = nested_loop_join.execute(i, context.clone())?;
+ let more_batches = common::collect(stream).await?;
+ batches.extend(
+ more_batches
+ .into_iter()
+ .filter(|b| b.num_rows() > 0) // drop empty batches
+ .collect::<Vec<_>>(),
+ );
+ }
+ Ok((columns, batches)) // (output column names, non-empty batches from all executed partitions)
+ }
+
+ #[tokio::test]
+ async fn join_inner_with_filter() -> Result<()> { // Inner join: only pairs accepted by the join filter are expected
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::Inner,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); // left columns followed by right columns
+ let expected = vec![
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 5 | 5 | 50 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_with_filter() -> Result<()> { // Left join: left rows without a passing pair are expected with empty right columns
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::Left,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); // left columns followed by right columns
+ let expected = vec![
+ "+----+----+-----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+-----+----+----+----+",
+ "| 11 | 8 | 110 | | | |",
+ "| 5 | 5 | 50 | 2 | 2 | 80 |",
+ "| 9 | 8 | 90 | | | |",
+ "+----+----+-----+----+----+----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_with_filter() -> Result<()> { // Right join: right rows without a passing pair are expected with empty left columns
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::Right,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); // left columns followed by right columns
+ let expected = vec![
+ "+----+----+----+----+----+-----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+----+----+----+-----+",
+ "| | | | 10 | 10 | 100 |",
+ "| | | | 12 | 10 | 40 |",
+ "| 5 | 5 | 50 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+-----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_full_with_filter() -> Result<()> { // Full join: unmatched rows of both sides are expected, each with the other side empty
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::Full,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); // left columns followed by right columns
+ let expected = vec![
+ "+----+----+-----+----+----+-----+",
+ "| a1 | b1 | c1 | a2 | b2 | c2 |",
+ "+----+----+-----+----+----+-----+",
+ "| | | | 10 | 10 | 100 |",
+ "| | | | 12 | 10 | 40 |",
+ "| 11 | 8 | 110 | | | |",
+ "| 5 | 5 | 50 | 2 | 2 | 80 |",
+ "| 9 | 8 | 90 | | | |",
+ "+----+----+-----+----+----+-----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_semi_with_filter() -> Result<()> { // LeftSemi: expects the left rows that have a passing pair
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::LeftSemi,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1"]); // only the left columns are output
+ let expected = vec![
+ "+----+----+----+",
+ "| a1 | b1 | c1 |",
+ "+----+----+----+",
+ "| 5 | 5 | 50 |",
+ "+----+----+----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_left_anti_with_filter() -> Result<()> { // LeftAnti: expects the left rows that have no passing pair
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::LeftAnti,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1"]); // only the left columns are output
+ let expected = vec![
+ "+----+----+-----+",
+ "| a1 | b1 | c1 |",
+ "+----+----+-----+",
+ "| 11 | 8 | 110 |",
+ "| 9 | 8 | 90 |",
+ "+----+----+-----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_semi_with_filter() -> Result<()> { // RightSemi: expects the right rows that have a passing pair
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::RightSemi,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a2", "b2", "c2"]); // only the right columns are output
+ let expected = vec![
+ "+----+----+----+",
+ "| a2 | b2 | c2 |",
+ "+----+----+----+",
+ "| 2 | 2 | 80 |",
+ "+----+----+----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_anti_with_filter() -> Result<()> { // RightAnti: expects the right rows that have no passing pair
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let left = build_left_table();
+ let right = build_right_table();
+
+ let filter = prepare_join_filter();
+ let (columns, batches) = multi_partitioned_join_collect(
+ left,
+ right,
+ &JoinType::RightAnti,
+ Some(filter),
+ task_ctx,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a2", "b2", "c2"]); // only the right columns are output
+ let expected = vec![
+ "+----+----+-----+",
+ "| a2 | b2 | c2 |",
+ "+----+----+-----+",
+ "| 10 | 10 | 100 |",
+ "| 12 | 10 | 40 |",
+ "+----+----+-----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+}
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs
index 936b9343b..11b984694 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -21,8 +21,15 @@ use crate::error::{DataFusionError, Result};
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::SchemaRef;
-use arrow::datatypes::{Field, Schema};
-use arrow::error::ArrowError;
+use arrow::array::{
+ new_null_array, Array, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
+ UInt32Builder, UInt64Array,
+};
+use arrow::compute;
+use arrow::datatypes::{Field, Schema, UInt32Type, UInt64Type};
+use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::record_batch::RecordBatch;
+use datafusion_common::cast::as_boolean_array;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::{EquivalentClass, PhysicalExpr};
use futures::future::{BoxFuture, Shared};
@@ -693,6 +700,242 @@ impl<T: 'static> OnceFut<T> {
}
}
+/// Returns `true` for the join types (`Left`, `LeftAnti`, `LeftSemi`, `Full`) that must maintain a bitmap of
+/// matched left rows and use that bitmap to generate part of the join result.
+///
+/// For example, a `Left` join gets the matched rows in each iteration over the right side, but it has to
+/// maintain the bitmap of matched indices in order to get the unmatched rows of the left side.
+pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool {
+ matches!(
+ join_type,
+ JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full
+ )
+}
+
+/// At the end of join execution, turns the bitmap of matched left rows into the final left and right
+/// indices: set bits for `LeftSemi`, unset bits otherwise; every returned right index is null.
+///
+/// For example:
+/// left_bit_map: [true, false, true, true, false]
+/// join_type: `Left`
+///
+/// The result is: ([1, 4], [null, null])
+pub(crate) fn get_final_indices_from_bit_map(
+ left_bit_map: &BooleanBufferBuilder,
+ join_type: JoinType,
+) -> (UInt64Array, UInt32Array) {
+ let left_size = left_bit_map.len();
+ let left_indices = if join_type == JoinType::LeftSemi { // semi join keeps the rows whose bit is set
+ (0..left_size)
+ .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64))
+ .collect::<UInt64Array>()
+ } else {
+ // this branch is meant only for `Left`, `LeftAnti` and `Full` joins:
+ // all three finally produce the left rows that never matched (bit not set)
+ (0..left_size)
+ .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64))
+ .collect::<UInt64Array>()
+ };
+ // right_indices:
+ // one null entry per selected left row, so both arrays have the same length
+ let mut builder = UInt32Builder::with_capacity(left_indices.len());
+ builder.append_nulls(left_indices.len());
+ let right_indices = builder.finish();
+ (left_indices, right_indices)
+}
+
+/// Builds an intermediate batch from the candidate row pairs named by `left_indices` / `right_indices`,
+/// evaluates `filter` on it, and returns both index arrays filtered by the resulting boolean mask.
+pub(crate) fn apply_join_filter_to_indices(
+ left: &RecordBatch,
+ right: &RecordBatch,
+ left_indices: UInt64Array,
+ right_indices: UInt32Array,
+ filter: &JoinFilter,
+) -> Result<(UInt64Array, UInt32Array)> {
+ if left_indices.is_empty() && right_indices.is_empty() { // no candidate pairs: return the inputs unchanged
+ return Ok((left_indices, right_indices));
+ };
+
+ let intermediate_batch = build_batch_from_indices( // laid out per the filter's own schema and column indices
+ filter.schema(),
+ left,
+ right,
+ PrimitiveArray::from(left_indices.data().clone()),
+ PrimitiveArray::from(right_indices.data().clone()),
+ filter.column_indices(),
+ )?;
+ let filter_result = filter
+ .expression()
+ .evaluate(&intermediate_batch)?
+ .into_array(intermediate_batch.num_rows());
+ let mask = as_boolean_array(&filter_result)?; // the filter expression is expected to yield a boolean array
+
+ let left_filtered = PrimitiveArray::<UInt64Type>::from( // the same mask is applied to both index arrays
+ compute::filter(&left_indices, mask)?.data().clone(),
+ );
+ let right_filtered = PrimitiveArray::<UInt32Type>::from(
+ compute::filter(&right_indices, mask)?.data().clone(),
+ );
+
+ Ok((left_filtered, right_filtered))
+}
+
+/// Returns a new [RecordBatch] whose columns are taken from `left` or `right` as described by `column_indices`,
+/// with rows selected by `left_indices` / `right_indices`. The resulting batch has [Schema] `schema`.
+pub(crate) fn build_batch_from_indices(
+ schema: &Schema,
+ left: &RecordBatch,
+ right: &RecordBatch,
+ left_indices: UInt64Array,
+ right_indices: UInt32Array,
+ column_indices: &[ColumnIndex],
+) -> ArrowResult<RecordBatch> {
+ // build the columns of the new [RecordBatch], one per entry of `column_indices`:
+ // 1. pick whether the column comes from the left or the right batch
+ // 2. based on the pick, `take` the rows named by that side's index array
+ let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
+
+ for column_index in column_indices {
+ let array = match column_index.side {
+ JoinSide::Left => {
+ let array = left.column(column_index.index);
+ if array.is_empty() || left_indices.null_count() == left_indices.len() {
+ // An outer join generates a null index when it finds no match on this side.
+ // Therefore this side may be empty while an n-length null array still has to be produced,
+ // where n is the length of the index array.
+ assert_eq!(left_indices.null_count(), left_indices.len()); // panics if the column is empty but some index is non-null
+ new_null_array(array.data_type(), left_indices.len())
+ } else {
+ compute::take(array.as_ref(), &left_indices, None)?
+ }
+ }
+ JoinSide::Right => {
+ let array = right.column(column_index.index);
+ if array.is_empty() || right_indices.null_count() == right_indices.len() { // same all-null shortcut as the left side
+ assert_eq!(right_indices.null_count(), right_indices.len());
+ new_null_array(array.data_type(), right_indices.len())
+ } else {
+ compute::take(array.as_ref(), &right_indices, None)?
+ }
+ }
+ };
+ columns.push(array);
+ }
+ RecordBatch::try_new(Arc::new(schema.clone()), columns)
+}
+
+/// Takes the matched left and right indices (`count_right_batch` is used as the right row count) and
+/// adjusts them according to `join_type`.
+pub(crate) fn adjust_indices_by_join_type(
+ left_indices: UInt64Array,
+ right_indices: UInt32Array,
+ count_right_batch: usize,
+ join_type: JoinType,
+) -> (UInt64Array, UInt32Array) {
+ match join_type {
+ JoinType::Inner => {
+ // matched pairs are output unchanged
+ (left_indices, right_indices)
+ }
+ JoinType::Left => {
+ // matched pairs are output unchanged
+ (left_indices, right_indices)
+ // unmatched left rows are produced at the end of the loop, based on the left visited bitmap
+ }
+ JoinType::Right | JoinType::Full => {
+ // matched pairs are kept, and
+ // the unmatched right rows are produced within this batch
+ let right_unmatched_indices =
+ get_anti_indices(count_right_batch, &right_indices);
+ // combine the matched and the unmatched right rows into one result
+ append_right_indices(left_indices, right_indices, right_unmatched_indices)
+ }
+ JoinType::RightSemi => {
+ // remove duplicated right indices so each matched right row appears once
+ let right_indices = get_semi_indices(count_right_batch, &right_indices);
+ // the left_indices will not be used later for the `right semi` join
+ (left_indices, right_indices)
+ }
+ JoinType::RightAnti => {
+ // duplicates do not matter here:
+ // take the right indices that do not appear among the matches
+ let right_indices = get_anti_indices(count_right_batch, &right_indices);
+ // the left_indices will not be used later for the `right anti` join
+ (left_indices, right_indices)
+ }
+ JoinType::LeftSemi | JoinType::LeftAnti => {
+ // matched (semi) or unmatched (anti) left rows are produced at the end of the loop, so nothing is output here
+ // NOTE(review): matched `LeftSemi` rows could be output while visiting the right batch instead of waiting; looks like a possible optimization, confirm
+ (
+ UInt64Array::from_iter_values(vec![]),
+ UInt32Array::from_iter_values(vec![]),
+ )
+ }
+ }
+}
+
+/// Appends `right_unmatched_indices` to the end of `right_indices`,
+/// and appends the same number of nulls to the end of `left_indices`, to
+/// keep the lengths of `right_indices` and `left_indices` consistent.
+pub(crate) fn append_right_indices(
+ left_indices: UInt64Array,
+ right_indices: UInt32Array,
+ right_unmatched_indices: UInt32Array,
+) -> (UInt64Array, UInt32Array) {
+ // precondition: left_indices, right_indices and right_unmatched_indices must not contain null values
+ if right_unmatched_indices.is_empty() {
+ (left_indices, right_indices)
+ } else {
+ let unmatched_size = right_unmatched_indices.len();
+ // the new left indices: left_indices followed by `unmatched_size` nulls
+ // the new right indices: right_indices followed by right_unmatched_indices
+ let new_left_indices = left_indices
+ .iter()
+ .chain(std::iter::repeat(None).take(unmatched_size))
+ .collect::<UInt64Array>();
+ let new_right_indices = right_indices
+ .iter()
+ .chain(right_unmatched_indices.iter())
+ .collect::<UInt32Array>();
+ (new_left_indices, new_right_indices)
+ }
+}
+
+/// Returns, in ascending order, every index in `0..row_count` that does not occur in `input_indices` (nulls are skipped).
+pub(crate) fn get_anti_indices(
+ row_count: usize,
+ input_indices: &UInt32Array,
+) -> UInt32Array {
+ let mut bitmap = BooleanBufferBuilder::new(row_count);
+ bitmap.append_n(row_count, false); // start with every row marked as not seen
+ input_indices.iter().flatten().for_each(|v| {
+ bitmap.set_bit(v as usize, true);
+ });
+
+ // anti indices: the positions whose bit was never set
+ (0..row_count)
+ .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32))
+ .collect::<UInt32Array>()
+}
+
+/// Returns, in ascending order and without duplicates, every index in `0..row_count` that occurs in `input_indices` (nulls are skipped).
+pub(crate) fn get_semi_indices(
+ row_count: usize,
+ input_indices: &UInt32Array,
+) -> UInt32Array {
+ let mut bitmap = BooleanBufferBuilder::new(row_count);
+ bitmap.append_n(row_count, false); // start with every row marked as not seen
+ input_indices.iter().flatten().for_each(|v| {
+ bitmap.set_bit(v as usize, true);
+ });
+
+ // semi indices: the positions whose bit was set at least once
+ (0..row_count)
+ .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32))
+ .collect::<UInt32Array>()
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 5d6f8c99c..64ce744fa 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -43,9 +43,9 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro
use crate::physical_plan::explain::ExplainExec;
use crate::physical_plan::expressions::{Column, PhysicalSortExpr};
use crate::physical_plan::filter::FilterExec;
-use crate::physical_plan::joins::CrossJoinExec;
use crate::physical_plan::joins::HashJoinExec;
use crate::physical_plan::joins::SortMergeJoinExec;
+use crate::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::repartition::RepartitionExec;
@@ -999,7 +999,16 @@ impl DefaultPhysicalPlanner {
.read()
.get_bool(OPT_PREFER_HASH_JOIN)
.unwrap_or_default();
- if session_state.config.target_partitions() > 1
+ if join_on.is_empty() {
+ // there is no equi-join condition (`join_on` is empty), so use the nested loop join
+ // TODO: optimize this plan and make use of the `target_partitions` and `repartition_joins` configs
+ Ok(Arc::new(NestedLoopJoinExec::try_new(
+ physical_left,
+ physical_right,
+ join_filter,
+ join_type,
+ )?))
+ } else if session_state.config.target_partitions() > 1
&& session_state.config.repartition_joins()
&& !prefer_hash_join
{
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index e3f92cbb3..70b781399 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2279,9 +2279,7 @@ async fn right_semi_join() -> Result<()> {
}
#[tokio::test]
-#[ignore = "Test ignored, will be enabled after fixing cross join bug"]
-// https://github.com/apache/arrow-datafusion/issues/4363
-async fn error_cross_join() -> Result<()> {
+async fn left_join_with_nonequal_condition() -> Result<()> {
let test_repartition_joins = vec![true, false];
for repartition_joins in test_repartition_joins {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins).unwrap();
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 6761eab48..dc43cbaf1 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -843,9 +843,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let join_filter = filter.into_iter().reduce(Expr::and);
- if left_keys.is_empty() {
- // TODO should not use cross join when the join_filter exists
- // https://github.com/apache/arrow-datafusion/issues/4363
+ if left_keys.is_empty() && join_filter.is_none() {
let mut join = LogicalPlanBuilder::from(left).cross_join(right)?;
if let Some(filter) = join_filter {
join = join.filter(filter)?;
@@ -1171,10 +1169,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)?
} else {
match having_expr_opt {
- Some(having_expr) => return Err(DataFusionError::Plan(
- format!("HAVING clause references: {} must appear in the GROUP BY clause or be used in an aggregate function", having_expr))),
- None => (plan, select_exprs, having_expr_opt)
- }
+ Some(having_expr) => return Err(DataFusionError::Plan(
+ format!("HAVING clause references: {} must appear in the GROUP BY clause or be used in an aggregate function", having_expr))),
+ None => (plan, select_exprs, having_expr_opt)
+ }
};
let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr {
@@ -1846,7 +1844,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Err(DataFusionError::Internal(format!(
"Invalid placeholder, not a number: {}",
param
- )))
+ )));
}
};
// Check if the placeholder is in the parameter list
@@ -1966,8 +1964,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
r @ OwnedTableReference::Bare { .. } |
r @ OwnedTableReference::Full { .. } => {
return Err(DataFusionError::Plan(format!(
- "Unsupported compound identifier '{:?}'", r,
- )))
+ "Unsupported compound identifier '{:?}'", r,
+ )));
}
};
@@ -2166,7 +2164,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
negated,
Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?),
Box::new(pattern),
- escape_char
+ escape_char,
)))
}
@@ -2389,13 +2387,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
- SQLExpr::Floor{expr, field: _field} => {
+ SQLExpr::Floor { expr, field: _field } => {
let fun = BuiltinScalarFunction::Floor;
let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?];
Ok(Expr::ScalarFunction { fun, args })
}
- SQLExpr::Ceil{expr, field: _field} => {
+ SQLExpr::Ceil { expr, field: _field } => {
let fun = BuiltinScalarFunction::Ceil;
let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?];
Ok(Expr::ScalarFunction { fun, args })
@@ -2698,7 +2696,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Err(DataFusionError::Plan(format!(
"Unsupported Value {}",
value[0]
- )))
+ )));
}
},
// for capture signed number e.g. +8, -8
@@ -2709,14 +2707,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Err(DataFusionError::Plan(format!(
"Unsupported Value {}",
value[0]
- )))
+ )));
}
},
_ => {
return Err(DataFusionError::Plan(format!(
"Unsupported Value {}",
value[0]
- )))
+ )));
}
};
@@ -2948,7 +2946,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Err(DataFusionError::Internal(format!(
"Incorrect data type for time_zone: {}",
v.get_datatype(),
- )))
+ )));
}
None => return Err(DataFusionError::Internal(
"Config Option datafusion.execution.time_zone doesn't exist"
@@ -2976,7 +2974,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
SQLDataType::Numeric(exact_number_info)
- |SQLDataType::Decimal(exact_number_info) => {
+ | SQLDataType::Decimal(exact_number_info) => {
let (precision, scale) = match *exact_number_info {
ExactNumberInfo::None => (None, None),
ExactNumberInfo::Precision(precision) => (Some(precision), None),
@@ -3008,12 +3006,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
| SQLDataType::CharacterVarying(_)
| SQLDataType::CharVarying(_)
| SQLDataType::CharacterLargeObject(_)
- | SQLDataType::CharLargeObject(_)
+ | SQLDataType::CharLargeObject(_)
// precision is not supported
- | SQLDataType::Timestamp(Some(_), _)
+ | SQLDataType::Timestamp(Some(_), _)
// precision is not supported
- | SQLDataType::Time(Some(_), _)
- | SQLDataType::Dec(_)
+ | SQLDataType::Time(Some(_), _)
+ | SQLDataType::Dec(_)
| SQLDataType::Clob(_) => Err(DataFusionError::NotImplemented(format!(
"Unsupported SQL type {:?}",
sql_type
@@ -4037,7 +4035,7 @@ mod tests {
"Projection: col1, col2\
\n Projection: t.column1 AS col1, t.column2 AS col2\
\n SubqueryAlias: t\
- \n Values: (CAST(Utf8(\"2021-06-10 17:01:00Z\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2004-04-09\") AS Date32))"
+ \n Values: (CAST(Utf8(\"2021-06-10 17:01:00Z\") AS Timestamp(Nanosecond, None)), CAST(Utf8(\"2004-04-09\") AS Date32))",
);
}
@@ -4285,7 +4283,7 @@ mod tests {
let sql = "SELECT age, MIN(first_name) FROM person GROUP BY age + 1";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!("Plan(\"Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)\")",
- format!("{:?}", err)
+ format!("{:?}", err)
);
}
@@ -5845,10 +5843,9 @@ mod tests {
FROM person \
JOIN orders ON id = customer_id OR person.age > 30";
let expected = "Projection: person.id, orders.order_id\
- \n Filter: person.id = orders.customer_id OR person.age > Int64(30)\
- \n CrossJoin:\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id = orders.customer_id OR person.age > Int64(30)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -5970,10 +5967,9 @@ mod tests {
ON person.id = 10";
let expected = "Projection: person.id, orders.order_id\
- \n Filter: person.id = Int64(10)\
- \n CrossJoin:\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id = Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6048,6 +6044,46 @@ mod tests {
quick_test(sql, expected);
}
+ #[test]
+ fn test_noneq_with_filter_join() { // an ON clause without an equality condition is expected to plan as a join with a filter
+ // inner join: expected plan is `Inner Join: Filter: ...` over the two table scans
+ let sql = "SELECT person.id, person.first_name \
+ FROM person INNER JOIN orders \
+ ON person.age > 10";
+ let expected = "Projection: person.id, person.first_name\
+ \n Inner Join: Filter: person.age > Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
+ quick_test(sql, expected);
+ // left join: same shape, with `Left Join`
+ let sql = "SELECT person.id, person.first_name \
+ FROM person LEFT JOIN orders \
+ ON person.age > 10";
+ let expected = "Projection: person.id, person.first_name\
+ \n Left Join: Filter: person.age > Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
+ quick_test(sql, expected);
+ // right join: same shape, with `Right Join`
+ let sql = "SELECT person.id, person.first_name \
+ FROM person RIGHT JOIN orders \
+ ON person.age > 10";
+ let expected = "Projection: person.id, person.first_name\
+ \n Right Join: Filter: person.age > Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
+ quick_test(sql, expected);
+ // full join: same shape, with `Full Join`
+ let sql = "SELECT person.id, person.first_name \
+ FROM person FULL JOIN orders \
+ ON person.age > 10";
+ let expected = "Projection: person.id, person.first_name\
+ \n Full Join: Filter: person.age > Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
+ quick_test(sql, expected);
+ }
+
#[test]
fn test_one_side_constant_full_join() {
// TODO: this sql should be parsed as join after
@@ -6058,10 +6094,9 @@ mod tests {
ON person.id = 10";
let expected = "Projection: person.id, orders.order_id\
- \n Filter: person.id = Int64(10)\
- \n CrossJoin:\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Full Join: Filter: person.id = Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6432,7 +6467,7 @@ mod tests {
ScalarValue::Utf8(Some("xyz".to_string())),
];
let expected_plan =
- "Projection: person.id, person.age, Utf8(\"xyz\")\
+ "Projection: person.id, person.age, Utf8(\"xyz\")\
\n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\
\n TableScan: person";
@@ -6469,7 +6504,7 @@ mod tests {
ScalarValue::Float64(Some(300.0)),
];
let expected_plan =
- "Projection: person.id, SUM(person.age)\
+ "Projection: person.id, SUM(person.age)\
\n Filter: SUM(person.age) < Int32(10) AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\
\n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\
\n Filter: person.salary > Float64(100)\