You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by yj...@apache.org on 2022/05/11 01:46:54 UTC

[arrow-datafusion] branch master updated: Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)

This is an automated email from the ASF dual-hosted git repository.

yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 04a97b62d Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)
04a97b62d is described below

commit 04a97b62dfe68157d3ac2fb522daeefcb6c96e21
Author: Zhang Li <ri...@gmail.com>
AuthorDate: Wed May 11 09:46:49 2022 +0800

    Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)
    
    * optimize smj
    
    * fix timer not working problem
    
    * implement smj's fmt_as() and relies_on_input_order()
    
    * add comments
    
    * add join_type checking in freeze methods
    
    Co-authored-by: zhangli20 <zh...@kuaishou.com>
---
 .../core/src/physical_plan/sort_merge_join.rs      | 474 +++++++++++----------
 1 file changed, 241 insertions(+), 233 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index c207917b6..e2248a99b 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -22,6 +22,7 @@
 use std::any::Any;
 use std::cmp::Ordering;
 use std::collections::VecDeque;
+use std::fmt::Formatter;
 use std::ops::Range;
 use std::pin::Pin;
 use std::sync::Arc;
@@ -44,8 +45,8 @@ use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::join_utils::{build_join_schema, check_join_is_valid, JoinOn};
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use crate::physical_plan::{
-    metrics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
-    Statistics,
+    metrics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
+    SendableRecordBatchStream, Statistics,
 };
 
 /// join execution plan executes partitions in parallel and combines them into a set of
@@ -128,6 +129,10 @@ impl ExecutionPlan for SortMergeJoinExec {
         self.right.output_ordering()
     }
 
+    fn relies_on_input_order(&self) -> bool {
+        true
+    }
+
     fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
         vec![self.left.clone(), self.right.clone()]
     }
@@ -201,6 +206,18 @@ impl ExecutionPlan for SortMergeJoinExec {
         Some(self.metrics.clone_inner())
     }
 
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(
+                    f,
+                    "SortMergeJoin: join_type={:?}, on={:?}, schema={:?}",
+                    self.join_type, self.on, &self.schema
+                )
+            }
+        }
+    }
+
     fn statistics(&self) -> Statistics {
         todo!()
     }
@@ -283,6 +300,33 @@ enum BufferedState {
     Exhausted,
 }
 
+struct StreamedBatch {
+    pub batch: RecordBatch,
+    pub idx: usize,
+    pub join_arrays: Vec<ArrayRef>,
+    pub null_joined: Vec<usize>,
+}
+impl StreamedBatch {
+    fn new(batch: RecordBatch, on_column: &[Column]) -> Self {
+        let join_arrays = join_arrays(&batch, on_column);
+        StreamedBatch {
+            batch,
+            idx: 0,
+            join_arrays,
+            null_joined: vec![],
+        }
+    }
+
+    fn new_empty(schema: SchemaRef) -> Self {
+        StreamedBatch {
+            batch: RecordBatch::new_empty(schema),
+            idx: 0,
+            join_arrays: vec![],
+            null_joined: vec![],
+        }
+    }
+}
+
 /// A buffered batch that contains contiguous rows with same join key
 #[derive(Debug)]
 struct BufferedBatch {
@@ -292,6 +336,10 @@ struct BufferedBatch {
     pub range: Range<usize>,
     /// Array refs of the join key
     pub join_arrays: Vec<ArrayRef>,
+    /// Buffered joined index (null joining buffered)
+    pub null_joined: Vec<usize>,
+    /// Buffered joined index (streamed joining buffered)
+    pub pair_joined: (Vec<usize>, Vec<usize>),
 }
 impl BufferedBatch {
     fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> Self {
@@ -300,6 +348,8 @@ impl BufferedBatch {
             batch,
             range,
             join_arrays,
+            null_joined: vec![],
+            pair_joined: (vec![], vec![]),
         }
     }
 }
@@ -324,11 +374,7 @@ struct SMJStream {
     /// Buffered data stream
     pub buffered: SendableRecordBatchStream,
     /// Current processing record batch of streamed
-    pub streamed_batch: RecordBatch,
-    /// Current processing streamed join arrays
-    pub streamed_join_arrays: Vec<ArrayRef>,
-    /// Current processing row of streamed
-    pub streamed_idx: usize,
+    pub streamed_batch: StreamedBatch,
     /// Current buffered data
     pub buffered_data: BufferedData,
     /// (used in outer join) Is current streamed row joined at least once?
@@ -347,7 +393,7 @@ struct SMJStream {
     pub on_buffered: Vec<Column>,
     /// Staging output array builders
     pub output_record_batches: Vec<RecordBatch>,
-    /// Staging output size
+    /// Staging output size, including output batches and staging joined results
     pub output_size: usize,
     /// Target output batch size
     pub batch_size: usize,
@@ -370,7 +416,9 @@ impl Stream for SMJStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        self.join_metrics.join_time.timer();
+        let join_time = self.join_metrics.join_time.clone();
+        let _timer = join_time.timer();
+
         loop {
             match &self.state {
                 SMJState::Init => {
@@ -428,10 +476,7 @@ impl Stream for SMJStream {
                     self.state = SMJState::JoinOutput;
                 }
                 SMJState::JoinOutput => {
-                    let output_indices = self.join_partial()?;
-                    if !output_indices.is_empty() {
-                        self.output_partial(&output_indices)?;
-                    }
+                    self.join_partial()?;
 
                     if self.output_size < self.batch_size {
                         if self.buffered_data.scanning_finished() {
@@ -439,11 +484,16 @@ impl Stream for SMJStream {
                             self.state = SMJState::Init;
                         }
                     } else {
-                        let record_batch = self.output_record_batch_and_reset()?;
-                        return Poll::Ready(Some(Ok(record_batch)));
+                        self.freeze_all()?;
+                        if !self.output_record_batches.is_empty() {
+                            let record_batch = self.output_record_batch_and_reset()?;
+                            return Poll::Ready(Some(Ok(record_batch)));
+                        }
+                        return Poll::Pending;
                     }
                 }
                 SMJState::Exhausted => {
+                    self.freeze_all()?;
                     if !self.output_record_batches.is_empty() {
                         let record_batch = self.output_record_batch_and_reset()?;
                         return Poll::Ready(Some(Ok(record_batch)));
@@ -469,18 +519,18 @@ impl SMJStream {
         batch_size: usize,
         join_metrics: SortMergeJoinMetrics,
     ) -> Result<Self> {
+        let streamed_schema = streamed.schema();
+        let buffered_schema = buffered.schema();
         Ok(Self {
             state: SMJState::Init,
             sort_options,
             null_equals_null,
-            schema: schema.clone(),
-            streamed_schema: streamed.schema(),
-            buffered_schema: buffered.schema(),
+            schema,
+            streamed_schema: streamed_schema.clone(),
+            buffered_schema,
             streamed,
             buffered,
-            streamed_batch: RecordBatch::new_empty(schema),
-            streamed_join_arrays: vec![],
-            streamed_idx: 0,
+            streamed_batch: StreamedBatch::new_empty(streamed_schema),
             buffered_data: BufferedData::default(),
             streamed_joined: false,
             buffered_joined: false,
@@ -502,8 +552,9 @@ impl SMJStream {
         loop {
             match &self.streamed_state {
                 StreamedState::Init => {
-                    if self.streamed_idx + 1 < self.streamed_batch.num_rows() {
-                        self.streamed_idx += 1;
+                    if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
+                    {
+                        self.streamed_batch.idx += 1;
                         self.streamed_state = StreamedState::Ready;
                         return Poll::Ready(Some(Ok(())));
                     } else {
@@ -520,12 +571,11 @@ impl SMJStream {
                     }
                     Poll::Ready(Some(batch)) => {
                         if batch.num_rows() > 0 {
+                            self.freeze_dequeuing_streamed()?;
                             self.join_metrics.input_batches.add(1);
                             self.join_metrics.input_rows.add(batch.num_rows());
-                            self.streamed_batch = batch;
-                            self.streamed_join_arrays =
-                                join_arrays(&self.streamed_batch, &self.on_streamed);
-                            self.streamed_idx = 0;
+                            self.streamed_batch =
+                                StreamedBatch::new(batch, &self.on_streamed);
                             self.streamed_state = StreamedState::Ready;
                         }
                     }
@@ -552,6 +602,7 @@ impl SMJStream {
                     while !self.buffered_data.batches.is_empty() {
                         let head_batch = self.buffered_data.head_batch();
                         if head_batch.range.end == head_batch.batch.num_rows() {
+                            self.freeze_dequeuing_buffered()?;
                             self.buffered_data.batches.pop_front();
                         } else {
                             break;
@@ -650,8 +701,8 @@ impl SMJStream {
         }
 
         return compare_join_arrays(
-            &self.streamed_join_arrays,
-            self.streamed_idx,
+            &self.streamed_batch.join_arrays,
+            self.streamed_batch.idx,
             &self.buffered_data.head_batch().join_arrays,
             self.buffered_data.head_batch().range.start,
             &self.sort_options,
@@ -661,7 +712,7 @@ impl SMJStream {
 
     /// Produce join and fill output buffer until reaching target batch size
     /// or the join is finished
-    fn join_partial(&mut self) -> ArrowResult<Vec<OutputIndex>> {
+    fn join_partial(&mut self) -> ArrowResult<()> {
         let mut join_streamed = false;
         let mut join_buffered = false;
 
@@ -696,28 +747,32 @@ impl SMJStream {
         if !join_streamed && !join_buffered {
             // no joined data
             self.buffered_data.scanning_finish();
-            return Ok(vec![]);
+            return Ok(());
         }
 
-        let mut output_indices = vec![];
-
         if join_buffered {
             // joining streamed/nulls and buffered
-            let streamed_idx = if join_streamed {
-                Some(self.streamed_idx)
-            } else {
-                None
-            };
             while !self.buffered_data.scanning_finished()
                 && self.output_size < self.batch_size
             {
-                output_indices.push(OutputIndex {
-                    streamed_idx,
-                    buffered_idx: Some((
-                        self.buffered_data.scanning_batch_idx,
-                        self.buffered_data.scanning_idx(),
-                    )),
-                });
+                let scanning_idx = self.buffered_data.scanning_idx();
+                if join_streamed {
+                    self.buffered_data
+                        .scanning_batch_mut()
+                        .pair_joined
+                        .0
+                        .push(self.streamed_batch.idx);
+                    self.buffered_data
+                        .scanning_batch_mut()
+                        .pair_joined
+                        .1
+                        .push(scanning_idx);
+                } else {
+                    self.buffered_data
+                        .scanning_batch_mut()
+                        .null_joined
+                        .push(scanning_idx);
+                }
                 self.output_size += 1;
                 self.buffered_data.scanning_advance();
 
@@ -728,194 +783,112 @@ impl SMJStream {
             }
         } else {
             // joining streamed and nulls
-            output_indices.push(OutputIndex {
-                streamed_idx: Some(self.streamed_idx),
-                buffered_idx: None,
-            });
+            self.streamed_batch
+                .null_joined
+                .push(self.streamed_batch.idx);
             self.output_size += 1;
             self.buffered_data.scanning_finish();
             self.streamed_joined = true;
         }
-        Ok(output_indices)
+        Ok(())
     }
 
-    fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
-        assert!(!self.output_record_batches.is_empty());
-
-        let record_batch =
-            combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
-        self.join_metrics.output_batches.add(1);
-        self.join_metrics.output_rows.add(record_batch.num_rows());
-        self.output_size = 0;
-        self.output_record_batches.clear();
-        Ok(record_batch)
+    fn freeze_all(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_null(self.buffered_data.batches.len())?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
+        Ok(())
     }
 
-    fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> {
-        match self.join_type {
-            JoinType::Inner => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-            }
-            JoinType::Left | JoinType::Right => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-            JoinType::Full => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-                self.output_partial_null_joining_buffered(output_indices)?;
-            }
-            JoinType::Semi | JoinType::Anti => {
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-        }
+    // freeze when dequeuing a streamed batch
+    fn freeze_dequeuing_streamed(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_buffered(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
-            if indices.is_empty() {
-                return ArrowResult::Ok(());
-            }
-
-            // take streamed columns
-            let streamed_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.streamed_idx.unwrap() as u64),
-            );
-            let mut streamed_columns = self
-                .streamed_batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &streamed_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // take buffered columns
-            let buffered_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.buffered_idx.unwrap().1 as u64),
-            );
-            let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx]
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &buffered_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // combine columns and produce record batch
-            let columns = match self.join_type {
-                JoinType::Inner | JoinType::Left | JoinType::Full => {
-                    streamed_columns.extend(buffered_columns);
-                    streamed_columns
-                }
-                JoinType::Right => {
-                    buffered_columns.extend(streamed_columns);
-                    buffered_columns
-                }
-                JoinType::Semi | JoinType::Anti => {
-                    unreachable!()
-                }
-            };
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-            Ok(())
-        };
-
-        let mut buffered_batch_idx = 0;
-        let mut indices = vec![];
-        for &index in output_indices
-            .iter()
-            .filter(|index| index.streamed_idx.is_some())
-            .filter(|index| index.buffered_idx.is_some())
-        {
-            let buffered_idx = index.buffered_idx.unwrap();
-            if index.buffered_idx.unwrap().0 != buffered_batch_idx {
-                output(buffered_batch_idx, &indices)?;
-                buffered_batch_idx = buffered_idx.0;
-                indices.clear();
-            }
-            indices.push(index);
-        }
-        output(buffered_batch_idx, &indices)?;
+    // freeze when dequeuing a buffered batch
+    fn freeze_dequeuing_buffered(&mut self) -> ArrowResult<()> {
+        self.freeze_buffered_join_streamed(1)?;
+        self.freeze_buffered_join_null(1)?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_null(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        // streamed joining null
+    // join_type must be one of: `Left`/`Right`/`Full`/`Semi`/`Anti`
+    fn freeze_streamed_join_null(&mut self) -> ArrowResult<()> {
+        if !matches!(
+            self.join_type,
+            JoinType::Left
+                | JoinType::Right
+                | JoinType::Full
+                | JoinType::Semi
+                | JoinType::Anti
+        ) {
+            return Ok(());
+        }
         let streamed_indices = UInt64Array::from_iter_values(
-            output_indices
+            self.streamed_batch
+                .null_joined
                 .iter()
-                .filter(|index| index.streamed_idx.is_some())
-                .filter(|index| index.buffered_idx.is_none())
-                .map(|index| index.streamed_idx.unwrap() as u64),
+                .map(|&index| index as u64),
         );
+        if streamed_indices.is_empty() {
+            return Ok(());
+        }
+        self.streamed_batch.null_joined.clear();
+
         let mut streamed_columns = self
             .streamed_batch
+            .batch
             .columns()
             .iter()
             .map(|column| take(column, &streamed_indices, None))
             .collect::<ArrowResult<Vec<_>>>()?;
 
-        let mut buffered_columns = self
-            .buffered_schema
-            .fields()
-            .iter()
-            .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
-            .collect::<Vec<_>>();
+        let columns = if matches!(self.join_type, JoinType::Semi | JoinType::Anti) {
+            streamed_columns
+        } else {
+            let mut buffered_columns = self
+                .buffered_schema
+                .fields()
+                .iter()
+                .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
+                .collect::<Vec<_>>();
 
-        let columns = match self.join_type {
-            JoinType::Inner => {
-                unreachable!()
-            }
-            JoinType::Left | JoinType::Full => {
-                streamed_columns.extend(buffered_columns);
-                streamed_columns
-            }
-            JoinType::Right => {
+            if matches!(self.join_type, JoinType::Right) {
                 buffered_columns.extend(streamed_columns);
                 buffered_columns
+            } else {
+                streamed_columns.extend(buffered_columns);
+                streamed_columns
             }
-            JoinType::Anti | JoinType::Semi => streamed_columns,
         };
-
-        if !streamed_indices.is_empty() {
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-        }
+        self.output_record_batches
+            .push(RecordBatch::try_new(self.schema.clone(), columns)?);
         Ok(())
     }
 
-    fn output_partial_null_joining_buffered(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
-            if indices.is_empty() {
-                return ArrowResult::Ok(());
-            }
-
-            // take buffered columns
+    // join_type must be `Full`
+    fn freeze_buffered_join_null(&mut self, batch_count: usize) -> ArrowResult<()> {
+        if !matches!(self.join_type, JoinType::Full) {
+            return Ok(());
+        }
+        for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
             let buffered_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.buffered_idx.unwrap().1 as u64),
+                buffered_batch.null_joined.iter().map(|&index| index as u64),
             );
-            let buffered_columns = self.buffered_data.batches[buffered_batch_idx]
+            if buffered_indices.is_empty() {
+                continue;
+            }
+            buffered_batch.null_joined.clear();
+
+            let buffered_columns = buffered_batch
                 .batch
                 .columns()
                 .iter()
                 .map(|column| take(column, &buffered_indices, None))
                 .collect::<ArrowResult<Vec<_>>>()?;
 
-            // create null streamed columns
             let mut streamed_columns = self
                 .streamed_schema
                 .fields()
@@ -923,43 +896,82 @@ impl SMJStream {
                 .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
                 .collect::<Vec<_>>();
 
-            // combine columns and produce record batch
-            let columns = match self.join_type {
-                JoinType::Full => {
-                    streamed_columns.extend(buffered_columns);
-                    streamed_columns
-                }
-                JoinType::Inner
-                | JoinType::Left
-                | JoinType::Right
-                | JoinType::Semi
-                | JoinType::Anti => {
-                    unreachable!()
-                }
-            };
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-            Ok(())
-        };
+            streamed_columns.extend(buffered_columns);
+            let columns = streamed_columns;
 
-        let mut buffered_batch_idx = 0;
-        let mut indices = vec![];
-        for &index in output_indices
-            .iter()
-            .filter(|index| index.streamed_idx.is_none())
-            .filter(|index| index.buffered_idx.is_some())
-        {
-            let buffered_idx = index.buffered_idx.unwrap();
-            if buffered_idx.0 != buffered_batch_idx {
-                output(buffered_batch_idx, &indices)?;
-                buffered_batch_idx = buffered_idx.0;
-                indices.clear();
+            self.output_record_batches
+                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+        }
+        Ok(())
+    }
+
+    // join_type must be `Inner`/`Left`/`Right`/`Full`
+    fn freeze_buffered_join_streamed(&mut self, batch_count: usize) -> ArrowResult<()> {
+        if !matches!(
+            self.join_type,
+            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
+        ) {
+            return Ok(());
+        }
+        for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
+            let buffered_indices = UInt64Array::from_iter_values(
+                buffered_batch
+                    .pair_joined
+                    .1
+                    .iter()
+                    .map(|&index| index as u64),
+            );
+            let streamed_indices = UInt64Array::from_iter_values(
+                buffered_batch
+                    .pair_joined
+                    .0
+                    .iter()
+                    .map(|&index| index as u64),
+            );
+            if buffered_indices.is_empty() {
+                continue;
             }
-            indices.push(index);
+            buffered_batch.pair_joined.0.clear();
+            buffered_batch.pair_joined.1.clear();
+
+            let mut buffered_columns = buffered_batch
+                .batch
+                .columns()
+                .iter()
+                .map(|column| take(column, &buffered_indices, None))
+                .collect::<ArrowResult<Vec<_>>>()?;
+
+            let mut streamed_columns = self
+                .streamed_batch
+                .batch
+                .columns()
+                .iter()
+                .map(|column| take(column, &streamed_indices, None))
+                .collect::<ArrowResult<Vec<_>>>()?;
+
+            let columns = if matches!(self.join_type, JoinType::Right) {
+                buffered_columns.extend(streamed_columns);
+                buffered_columns
+            } else {
+                streamed_columns.extend(buffered_columns);
+                streamed_columns
+            };
+
+            self.output_record_batches
+                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
         }
-        output(buffered_batch_idx, &indices)?;
         Ok(())
     }
+
+    fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
+        let record_batch =
+            combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
+        self.join_metrics.output_batches.add(1);
+        self.join_metrics.output_rows.add(record_batch.num_rows());
+        self.output_size -= record_batch.num_rows();
+        self.output_record_batches.clear();
+        Ok(record_batch)
+    }
 }
 
 /// Buffered data contains all buffered batches with one unique join key
@@ -1006,6 +1018,10 @@ impl BufferedData {
         &self.batches[self.scanning_batch_idx]
     }
 
+    pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
+        &mut self.batches[self.scanning_batch_idx]
+    }
+
     pub fn scanning_idx(&self) -> usize {
         self.scanning_batch().range.start + self.scanning_offset
     }
@@ -1024,14 +1040,6 @@ impl BufferedData {
     }
 }
 
-#[derive(Clone, Copy, Debug)]
-struct OutputIndex {
-    /// joined streamed row index
-    streamed_idx: Option<usize>,
-    /// joined buffered batch index and row index
-    buffered_idx: Option<(usize, usize)>,
-}
-
 /// Get join array refs of given batch and join columns
 fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
     on_column