You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by yj...@apache.org on 2022/05/11 01:46:54 UTC
[arrow-datafusion] branch master updated: Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)
This is an automated email from the ASF dual-hosted git repository.
yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 04a97b62d Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)
04a97b62d is described below
commit 04a97b62dfe68157d3ac2fb522daeefcb6c96e21
Author: Zhang Li <ri...@gmail.com>
AuthorDate: Wed May 11 09:46:49 2022 +0800
Optimize MergeJoin by storing joined indices instead of creating small record batches for each match (#2492)
* optimize smj
* fix timer not working problem
* implement smj's fmt_as() and relies_on_input_order()
* add comments
* add join_type checking in freeze methods
Co-authored-by: zhangli20 <zh...@kuaishou.com>
---
.../core/src/physical_plan/sort_merge_join.rs | 474 +++++++++++----------
1 file changed, 241 insertions(+), 233 deletions(-)
diff --git a/datafusion/core/src/physical_plan/sort_merge_join.rs b/datafusion/core/src/physical_plan/sort_merge_join.rs
index c207917b6..e2248a99b 100644
--- a/datafusion/core/src/physical_plan/sort_merge_join.rs
+++ b/datafusion/core/src/physical_plan/sort_merge_join.rs
@@ -22,6 +22,7 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
+use std::fmt::Formatter;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
@@ -44,8 +45,8 @@ use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::join_utils::{build_join_schema, check_join_is_valid, JoinOn};
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::physical_plan::{
- metrics, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
- Statistics,
+ metrics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
+ SendableRecordBatchStream, Statistics,
};
/// join execution plan executes partitions in parallel and combines them into a set of
@@ -128,6 +129,10 @@ impl ExecutionPlan for SortMergeJoinExec {
self.right.output_ordering()
}
+ fn relies_on_input_order(&self) -> bool {
+ true
+ }
+
fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.left.clone(), self.right.clone()]
}
@@ -201,6 +206,18 @@ impl ExecutionPlan for SortMergeJoinExec {
Some(self.metrics.clone_inner())
}
+ fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ write!(
+ f,
+ "SortMergeJoin: join_type={:?}, on={:?}, schema={:?}",
+ self.join_type, self.on, &self.schema
+ )
+ }
+ }
+ }
+
fn statistics(&self) -> Statistics {
todo!()
}
@@ -283,6 +300,33 @@ enum BufferedState {
Exhausted,
}
+struct StreamedBatch {
+ pub batch: RecordBatch,
+ pub idx: usize,
+ pub join_arrays: Vec<ArrayRef>,
+ pub null_joined: Vec<usize>,
+}
+impl StreamedBatch {
+ fn new(batch: RecordBatch, on_column: &[Column]) -> Self {
+ let join_arrays = join_arrays(&batch, on_column);
+ StreamedBatch {
+ batch,
+ idx: 0,
+ join_arrays,
+ null_joined: vec![],
+ }
+ }
+
+ fn new_empty(schema: SchemaRef) -> Self {
+ StreamedBatch {
+ batch: RecordBatch::new_empty(schema),
+ idx: 0,
+ join_arrays: vec![],
+ null_joined: vec![],
+ }
+ }
+}
+
/// A buffered batch that contains contiguous rows with same join key
#[derive(Debug)]
struct BufferedBatch {
@@ -292,6 +336,10 @@ struct BufferedBatch {
pub range: Range<usize>,
/// Array refs of the join key
pub join_arrays: Vec<ArrayRef>,
+ /// Buffered joined index (null joining buffered)
+ pub null_joined: Vec<usize>,
+ /// Buffered joined index (streamed joining buffered)
+ pub pair_joined: (Vec<usize>, Vec<usize>),
}
impl BufferedBatch {
fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> Self {
@@ -300,6 +348,8 @@ impl BufferedBatch {
batch,
range,
join_arrays,
+ null_joined: vec![],
+ pair_joined: (vec![], vec![]),
}
}
}
@@ -324,11 +374,7 @@ struct SMJStream {
/// Buffered data stream
pub buffered: SendableRecordBatchStream,
/// Current processing record batch of streamed
- pub streamed_batch: RecordBatch,
- /// Current processing streamed join arrays
- pub streamed_join_arrays: Vec<ArrayRef>,
- /// Current processing row of streamed
- pub streamed_idx: usize,
+ pub streamed_batch: StreamedBatch,
/// Current buffered data
pub buffered_data: BufferedData,
/// (used in outer join) Is current streamed row joined at least once?
@@ -347,7 +393,7 @@ struct SMJStream {
pub on_buffered: Vec<Column>,
/// Staging output array builders
pub output_record_batches: Vec<RecordBatch>,
- /// Staging output size
+ /// Staging output size, including output batches and staging joined results
pub output_size: usize,
/// Target output batch size
pub batch_size: usize,
@@ -370,7 +416,9 @@ impl Stream for SMJStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
- self.join_metrics.join_time.timer();
+ let join_time = self.join_metrics.join_time.clone();
+ let _timer = join_time.timer();
+
loop {
match &self.state {
SMJState::Init => {
@@ -428,10 +476,7 @@ impl Stream for SMJStream {
self.state = SMJState::JoinOutput;
}
SMJState::JoinOutput => {
- let output_indices = self.join_partial()?;
- if !output_indices.is_empty() {
- self.output_partial(&output_indices)?;
- }
+ self.join_partial()?;
if self.output_size < self.batch_size {
if self.buffered_data.scanning_finished() {
@@ -439,11 +484,16 @@ impl Stream for SMJStream {
self.state = SMJState::Init;
}
} else {
- let record_batch = self.output_record_batch_and_reset()?;
- return Poll::Ready(Some(Ok(record_batch)));
+ self.freeze_all()?;
+ if !self.output_record_batches.is_empty() {
+ let record_batch = self.output_record_batch_and_reset()?;
+ return Poll::Ready(Some(Ok(record_batch)));
+ }
+ return Poll::Pending;
}
}
SMJState::Exhausted => {
+ self.freeze_all()?;
if !self.output_record_batches.is_empty() {
let record_batch = self.output_record_batch_and_reset()?;
return Poll::Ready(Some(Ok(record_batch)));
@@ -469,18 +519,18 @@ impl SMJStream {
batch_size: usize,
join_metrics: SortMergeJoinMetrics,
) -> Result<Self> {
+ let streamed_schema = streamed.schema();
+ let buffered_schema = buffered.schema();
Ok(Self {
state: SMJState::Init,
sort_options,
null_equals_null,
- schema: schema.clone(),
- streamed_schema: streamed.schema(),
- buffered_schema: buffered.schema(),
+ schema,
+ streamed_schema: streamed_schema.clone(),
+ buffered_schema,
streamed,
buffered,
- streamed_batch: RecordBatch::new_empty(schema),
- streamed_join_arrays: vec![],
- streamed_idx: 0,
+ streamed_batch: StreamedBatch::new_empty(streamed_schema),
buffered_data: BufferedData::default(),
streamed_joined: false,
buffered_joined: false,
@@ -502,8 +552,9 @@ impl SMJStream {
loop {
match &self.streamed_state {
StreamedState::Init => {
- if self.streamed_idx + 1 < self.streamed_batch.num_rows() {
- self.streamed_idx += 1;
+ if self.streamed_batch.idx + 1 < self.streamed_batch.batch.num_rows()
+ {
+ self.streamed_batch.idx += 1;
self.streamed_state = StreamedState::Ready;
return Poll::Ready(Some(Ok(())));
} else {
@@ -520,12 +571,11 @@ impl SMJStream {
}
Poll::Ready(Some(batch)) => {
if batch.num_rows() > 0 {
+ self.freeze_dequeuing_streamed()?;
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
- self.streamed_batch = batch;
- self.streamed_join_arrays =
- join_arrays(&self.streamed_batch, &self.on_streamed);
- self.streamed_idx = 0;
+ self.streamed_batch =
+ StreamedBatch::new(batch, &self.on_streamed);
self.streamed_state = StreamedState::Ready;
}
}
@@ -552,6 +602,7 @@ impl SMJStream {
while !self.buffered_data.batches.is_empty() {
let head_batch = self.buffered_data.head_batch();
if head_batch.range.end == head_batch.batch.num_rows() {
+ self.freeze_dequeuing_buffered()?;
self.buffered_data.batches.pop_front();
} else {
break;
@@ -650,8 +701,8 @@ impl SMJStream {
}
return compare_join_arrays(
- &self.streamed_join_arrays,
- self.streamed_idx,
+ &self.streamed_batch.join_arrays,
+ self.streamed_batch.idx,
&self.buffered_data.head_batch().join_arrays,
self.buffered_data.head_batch().range.start,
&self.sort_options,
@@ -661,7 +712,7 @@ impl SMJStream {
/// Produce join and fill output buffer until reaching target batch size
/// or the join is finished
- fn join_partial(&mut self) -> ArrowResult<Vec<OutputIndex>> {
+ fn join_partial(&mut self) -> ArrowResult<()> {
let mut join_streamed = false;
let mut join_buffered = false;
@@ -696,28 +747,32 @@ impl SMJStream {
if !join_streamed && !join_buffered {
// no joined data
self.buffered_data.scanning_finish();
- return Ok(vec![]);
+ return Ok(());
}
- let mut output_indices = vec![];
-
if join_buffered {
// joining streamed/nulls and buffered
- let streamed_idx = if join_streamed {
- Some(self.streamed_idx)
- } else {
- None
- };
while !self.buffered_data.scanning_finished()
&& self.output_size < self.batch_size
{
- output_indices.push(OutputIndex {
- streamed_idx,
- buffered_idx: Some((
- self.buffered_data.scanning_batch_idx,
- self.buffered_data.scanning_idx(),
- )),
- });
+ let scanning_idx = self.buffered_data.scanning_idx();
+ if join_streamed {
+ self.buffered_data
+ .scanning_batch_mut()
+ .pair_joined
+ .0
+ .push(self.streamed_batch.idx);
+ self.buffered_data
+ .scanning_batch_mut()
+ .pair_joined
+ .1
+ .push(scanning_idx);
+ } else {
+ self.buffered_data
+ .scanning_batch_mut()
+ .null_joined
+ .push(scanning_idx);
+ }
self.output_size += 1;
self.buffered_data.scanning_advance();
@@ -728,194 +783,112 @@ impl SMJStream {
}
} else {
// joining streamed and nulls
- output_indices.push(OutputIndex {
- streamed_idx: Some(self.streamed_idx),
- buffered_idx: None,
- });
+ self.streamed_batch
+ .null_joined
+ .push(self.streamed_batch.idx);
self.output_size += 1;
self.buffered_data.scanning_finish();
self.streamed_joined = true;
}
- Ok(output_indices)
+ Ok(())
}
- fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
- assert!(!self.output_record_batches.is_empty());
-
- let record_batch =
- combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(record_batch.num_rows());
- self.output_size = 0;
- self.output_record_batches.clear();
- Ok(record_batch)
+ fn freeze_all(&mut self) -> ArrowResult<()> {
+ self.freeze_streamed_join_null()?;
+ self.freeze_buffered_join_null(self.buffered_data.batches.len())?;
+ self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
+ Ok(())
}
- fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> {
- match self.join_type {
- JoinType::Inner => {
- self.output_partial_streamed_joining_buffered(output_indices)?;
- }
- JoinType::Left | JoinType::Right => {
- self.output_partial_streamed_joining_buffered(output_indices)?;
- self.output_partial_streamed_joining_null(output_indices)?;
- }
- JoinType::Full => {
- self.output_partial_streamed_joining_buffered(output_indices)?;
- self.output_partial_streamed_joining_null(output_indices)?;
- self.output_partial_null_joining_buffered(output_indices)?;
- }
- JoinType::Semi | JoinType::Anti => {
- self.output_partial_streamed_joining_null(output_indices)?;
- }
- }
+ // freeze when dequeuing a streamed batch
+ fn freeze_dequeuing_streamed(&mut self) -> ArrowResult<()> {
+ self.freeze_streamed_join_null()?;
+ self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
Ok(())
}
- fn output_partial_streamed_joining_buffered(
- &mut self,
- output_indices: &[OutputIndex],
- ) -> ArrowResult<()> {
- let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
- if indices.is_empty() {
- return ArrowResult::Ok(());
- }
-
- // take streamed columns
- let streamed_indices = UInt64Array::from_iter_values(
- indices
- .iter()
- .map(|index| index.streamed_idx.unwrap() as u64),
- );
- let mut streamed_columns = self
- .streamed_batch
- .columns()
- .iter()
- .map(|column| take(column, &streamed_indices, None))
- .collect::<ArrowResult<Vec<_>>>()?;
-
- // take buffered columns
- let buffered_indices = UInt64Array::from_iter_values(
- indices
- .iter()
- .map(|index| index.buffered_idx.unwrap().1 as u64),
- );
- let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx]
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<ArrowResult<Vec<_>>>()?;
-
- // combine columns and produce record batch
- let columns = match self.join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full => {
- streamed_columns.extend(buffered_columns);
- streamed_columns
- }
- JoinType::Right => {
- buffered_columns.extend(streamed_columns);
- buffered_columns
- }
- JoinType::Semi | JoinType::Anti => {
- unreachable!()
- }
- };
- let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
- self.output_record_batches.push(record_batch);
- Ok(())
- };
-
- let mut buffered_batch_idx = 0;
- let mut indices = vec![];
- for &index in output_indices
- .iter()
- .filter(|index| index.streamed_idx.is_some())
- .filter(|index| index.buffered_idx.is_some())
- {
- let buffered_idx = index.buffered_idx.unwrap();
- if index.buffered_idx.unwrap().0 != buffered_batch_idx {
- output(buffered_batch_idx, &indices)?;
- buffered_batch_idx = buffered_idx.0;
- indices.clear();
- }
- indices.push(index);
- }
- output(buffered_batch_idx, &indices)?;
+ // freeze when dequeuing a buffered batch
+ fn freeze_dequeuing_buffered(&mut self) -> ArrowResult<()> {
+ self.freeze_buffered_join_streamed(1)?;
+ self.freeze_buffered_join_null(1)?;
Ok(())
}
- fn output_partial_streamed_joining_null(
- &mut self,
- output_indices: &[OutputIndex],
- ) -> ArrowResult<()> {
- // streamed joining null
+ // join_type must be one of: `Left`/`Right`/`Full`/`Semi`/`Anti`
+ fn freeze_streamed_join_null(&mut self) -> ArrowResult<()> {
+ if !matches!(
+ self.join_type,
+ JoinType::Left
+ | JoinType::Right
+ | JoinType::Full
+ | JoinType::Semi
+ | JoinType::Anti
+ ) {
+ return Ok(());
+ }
let streamed_indices = UInt64Array::from_iter_values(
- output_indices
+ self.streamed_batch
+ .null_joined
.iter()
- .filter(|index| index.streamed_idx.is_some())
- .filter(|index| index.buffered_idx.is_none())
- .map(|index| index.streamed_idx.unwrap() as u64),
+ .map(|&index| index as u64),
);
+ if streamed_indices.is_empty() {
+ return Ok(());
+ }
+ self.streamed_batch.null_joined.clear();
+
let mut streamed_columns = self
.streamed_batch
+ .batch
.columns()
.iter()
.map(|column| take(column, &streamed_indices, None))
.collect::<ArrowResult<Vec<_>>>()?;
- let mut buffered_columns = self
- .buffered_schema
- .fields()
- .iter()
- .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
- .collect::<Vec<_>>();
+ let columns = if matches!(self.join_type, JoinType::Semi | JoinType::Anti) {
+ streamed_columns
+ } else {
+ let mut buffered_columns = self
+ .buffered_schema
+ .fields()
+ .iter()
+ .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
+ .collect::<Vec<_>>();
- let columns = match self.join_type {
- JoinType::Inner => {
- unreachable!()
- }
- JoinType::Left | JoinType::Full => {
- streamed_columns.extend(buffered_columns);
- streamed_columns
- }
- JoinType::Right => {
+ if matches!(self.join_type, JoinType::Right) {
buffered_columns.extend(streamed_columns);
buffered_columns
+ } else {
+ streamed_columns.extend(buffered_columns);
+ streamed_columns
}
- JoinType::Anti | JoinType::Semi => streamed_columns,
};
-
- if !streamed_indices.is_empty() {
- let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
- self.output_record_batches.push(record_batch);
- }
+ self.output_record_batches
+ .push(RecordBatch::try_new(self.schema.clone(), columns)?);
Ok(())
}
- fn output_partial_null_joining_buffered(
- &mut self,
- output_indices: &[OutputIndex],
- ) -> ArrowResult<()> {
- let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
- if indices.is_empty() {
- return ArrowResult::Ok(());
- }
-
- // take buffered columns
+ // join_type must be `Full`
+ fn freeze_buffered_join_null(&mut self, batch_count: usize) -> ArrowResult<()> {
+ if !matches!(self.join_type, JoinType::Full) {
+ return Ok(());
+ }
+ for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
let buffered_indices = UInt64Array::from_iter_values(
- indices
- .iter()
- .map(|index| index.buffered_idx.unwrap().1 as u64),
+ buffered_batch.null_joined.iter().map(|&index| index as u64),
);
- let buffered_columns = self.buffered_data.batches[buffered_batch_idx]
+ if buffered_indices.is_empty() {
+ continue;
+ }
+ buffered_batch.null_joined.clear();
+
+ let buffered_columns = buffered_batch
.batch
.columns()
.iter()
.map(|column| take(column, &buffered_indices, None))
.collect::<ArrowResult<Vec<_>>>()?;
- // create null streamed columns
let mut streamed_columns = self
.streamed_schema
.fields()
@@ -923,43 +896,82 @@ impl SMJStream {
.map(|f| new_null_array(f.data_type(), buffered_indices.len()))
.collect::<Vec<_>>();
- // combine columns and produce record batch
- let columns = match self.join_type {
- JoinType::Full => {
- streamed_columns.extend(buffered_columns);
- streamed_columns
- }
- JoinType::Inner
- | JoinType::Left
- | JoinType::Right
- | JoinType::Semi
- | JoinType::Anti => {
- unreachable!()
- }
- };
- let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
- self.output_record_batches.push(record_batch);
- Ok(())
- };
+ streamed_columns.extend(buffered_columns);
+ let columns = streamed_columns;
- let mut buffered_batch_idx = 0;
- let mut indices = vec![];
- for &index in output_indices
- .iter()
- .filter(|index| index.streamed_idx.is_none())
- .filter(|index| index.buffered_idx.is_some())
- {
- let buffered_idx = index.buffered_idx.unwrap();
- if buffered_idx.0 != buffered_batch_idx {
- output(buffered_batch_idx, &indices)?;
- buffered_batch_idx = buffered_idx.0;
- indices.clear();
+ self.output_record_batches
+ .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+ }
+ Ok(())
+ }
+
+ // join_type must be `Inner`/`Left`/`Right`/`Full`
+ fn freeze_buffered_join_streamed(&mut self, batch_count: usize) -> ArrowResult<()> {
+ if !matches!(
+ self.join_type,
+ JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full
+ ) {
+ return Ok(());
+ }
+ for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
+ let buffered_indices = UInt64Array::from_iter_values(
+ buffered_batch
+ .pair_joined
+ .1
+ .iter()
+ .map(|&index| index as u64),
+ );
+ let streamed_indices = UInt64Array::from_iter_values(
+ buffered_batch
+ .pair_joined
+ .0
+ .iter()
+ .map(|&index| index as u64),
+ );
+ if buffered_indices.is_empty() {
+ continue;
}
- indices.push(index);
+ buffered_batch.pair_joined.0.clear();
+ buffered_batch.pair_joined.1.clear();
+
+ let mut buffered_columns = buffered_batch
+ .batch
+ .columns()
+ .iter()
+ .map(|column| take(column, &buffered_indices, None))
+ .collect::<ArrowResult<Vec<_>>>()?;
+
+ let mut streamed_columns = self
+ .streamed_batch
+ .batch
+ .columns()
+ .iter()
+ .map(|column| take(column, &streamed_indices, None))
+ .collect::<ArrowResult<Vec<_>>>()?;
+
+ let columns = if matches!(self.join_type, JoinType::Right) {
+ buffered_columns.extend(streamed_columns);
+ buffered_columns
+ } else {
+ streamed_columns.extend(buffered_columns);
+ streamed_columns
+ };
+
+ self.output_record_batches
+ .push(RecordBatch::try_new(self.schema.clone(), columns)?);
}
- output(buffered_batch_idx, &indices)?;
Ok(())
}
+
+ fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
+ let record_batch =
+ combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(record_batch.num_rows());
+ self.output_size -= record_batch.num_rows();
+ self.output_record_batches.clear();
+ Ok(record_batch)
+ }
}
/// Buffered data contains all buffered batches with one unique join key
@@ -1006,6 +1018,10 @@ impl BufferedData {
&self.batches[self.scanning_batch_idx]
}
+ pub fn scanning_batch_mut(&mut self) -> &mut BufferedBatch {
+ &mut self.batches[self.scanning_batch_idx]
+ }
+
pub fn scanning_idx(&self) -> usize {
self.scanning_batch().range.start + self.scanning_offset
}
@@ -1024,14 +1040,6 @@ impl BufferedData {
}
}
-#[derive(Clone, Copy, Debug)]
-struct OutputIndex {
- /// joined streamed row index
- streamed_idx: Option<usize>,
- /// joined buffered batch index and row index
- buffered_idx: Option<(usize, usize)>,
-}
-
/// Get join array refs of given batch and join columns
fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
on_column