You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/05/09 10:10:39 UTC

[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #2492: Smj optimize

yjshen commented on code in PR #2492:
URL: https://github.com/apache/arrow-datafusion/pull/2492#discussion_r867850581


##########
datafusion/core/src/physical_plan/sort_merge_join.rs:
##########
@@ -730,238 +785,176 @@ impl SMJStream {
             }
         } else {
             // joining streamed and nulls
-            output_indices.push(OutputIndex {
-                streamed_idx: Some(self.streamed_idx),
-                buffered_idx: None,
-            });
+            self.streamed_batch
+                .null_joined
+                .push(self.streamed_batch.idx);
             self.output_size += 1;
             self.buffered_data.scanning_finish();
             self.streamed_joined = true;
         }
-        Ok(output_indices)
+        Ok(())
     }
 
-    fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
-        assert!(!self.output_record_batches.is_empty());
-
-        let record_batch =
-            combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
-        self.join_metrics.output_batches.add(1);
-        self.join_metrics.output_rows.add(record_batch.num_rows());
-        self.output_size = 0;
-        self.output_record_batches.clear();
-        Ok(record_batch)
+    fn freeze_all(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_null(self.buffered_data.batches.len())?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
+        Ok(())
     }
 
-    fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> {
-        match self.join_type {
-            JoinType::Inner => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-            }
-            JoinType::Left | JoinType::Right => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-            JoinType::Full => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-                self.output_partial_null_joining_buffered(output_indices)?;
-            }
-            JoinType::Semi | JoinType::Anti => {
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-        }
+    // freeze when a dequeueing streamed batch
+    fn freeze_dequeuing_streamed(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_buffered(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
-            if indices.is_empty() {
-                return ArrowResult::Ok(());
-            }
-
-            // take streamed columns
-            let streamed_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.streamed_idx.unwrap() as u64),
-            );
-            let mut streamed_columns = self
-                .streamed_batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &streamed_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // take buffered columns
-            let buffered_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.buffered_idx.unwrap().1 as u64),
-            );
-            let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx]
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &buffered_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // combine columns and produce record batch
-            let columns = match self.join_type {
-                JoinType::Inner | JoinType::Left | JoinType::Full => {
-                    streamed_columns.extend(buffered_columns);
-                    streamed_columns
-                }
-                JoinType::Right => {
-                    buffered_columns.extend(streamed_columns);
-                    buffered_columns
-                }
-                JoinType::Semi | JoinType::Anti => {
-                    unreachable!()
-                }
-            };
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-            Ok(())
-        };
-
-        let mut buffered_batch_idx = 0;
-        let mut indices = vec![];
-        for &index in output_indices
-            .iter()
-            .filter(|index| index.streamed_idx.is_some())
-            .filter(|index| index.buffered_idx.is_some())
-        {
-            let buffered_idx = index.buffered_idx.unwrap();
-            if index.buffered_idx.unwrap().0 != buffered_batch_idx {
-                output(buffered_batch_idx, &indices)?;
-                buffered_batch_idx = buffered_idx.0;
-                indices.clear();
-            }
-            indices.push(index);
-        }
-        output(buffered_batch_idx, &indices)?;
+    // freeze when a dequeueing streamed batch
+    fn freeze_dequeuing_buffered(&mut self) -> ArrowResult<()> {
+        self.freeze_buffered_join_streamed(1)?;
+        self.freeze_buffered_join_null(1)?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_null(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        // streamed joining null
+    // join_type must be one of: `Left`/`Right`/`Full`/`Semi`/`Anti`
+    fn freeze_streamed_join_null(&mut self) -> ArrowResult<()> {

Review Comment:
   Shall we pass in the join_type and add an assertion here?



##########
datafusion/core/src/physical_plan/sort_merge_join.rs:
##########
@@ -730,238 +785,176 @@ impl SMJStream {
             }
         } else {
             // joining streamed and nulls
-            output_indices.push(OutputIndex {
-                streamed_idx: Some(self.streamed_idx),
-                buffered_idx: None,
-            });
+            self.streamed_batch
+                .null_joined
+                .push(self.streamed_batch.idx);
             self.output_size += 1;
             self.buffered_data.scanning_finish();
             self.streamed_joined = true;
         }
-        Ok(output_indices)
+        Ok(())
     }
 
-    fn output_record_batch_and_reset(&mut self) -> ArrowResult<RecordBatch> {
-        assert!(!self.output_record_batches.is_empty());
-
-        let record_batch =
-            combine_batches(&self.output_record_batches, self.schema.clone())?.unwrap();
-        self.join_metrics.output_batches.add(1);
-        self.join_metrics.output_rows.add(record_batch.num_rows());
-        self.output_size = 0;
-        self.output_record_batches.clear();
-        Ok(record_batch)
+    fn freeze_all(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_null(self.buffered_data.batches.len())?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
+        Ok(())
     }
 
-    fn output_partial(&mut self, output_indices: &[OutputIndex]) -> ArrowResult<()> {
-        match self.join_type {
-            JoinType::Inner => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-            }
-            JoinType::Left | JoinType::Right => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-            JoinType::Full => {
-                self.output_partial_streamed_joining_buffered(output_indices)?;
-                self.output_partial_streamed_joining_null(output_indices)?;
-                self.output_partial_null_joining_buffered(output_indices)?;
-            }
-            JoinType::Semi | JoinType::Anti => {
-                self.output_partial_streamed_joining_null(output_indices)?;
-            }
-        }
+    // freeze when a dequeueing streamed batch
+    fn freeze_dequeuing_streamed(&mut self) -> ArrowResult<()> {
+        self.freeze_streamed_join_null()?;
+        self.freeze_buffered_join_streamed(self.buffered_data.batches.len())?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_buffered(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
-            if indices.is_empty() {
-                return ArrowResult::Ok(());
-            }
-
-            // take streamed columns
-            let streamed_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.streamed_idx.unwrap() as u64),
-            );
-            let mut streamed_columns = self
-                .streamed_batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &streamed_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // take buffered columns
-            let buffered_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.buffered_idx.unwrap().1 as u64),
-            );
-            let mut buffered_columns = self.buffered_data.batches[buffered_batch_idx]
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &buffered_indices, None))
-                .collect::<ArrowResult<Vec<_>>>()?;
-
-            // combine columns and produce record batch
-            let columns = match self.join_type {
-                JoinType::Inner | JoinType::Left | JoinType::Full => {
-                    streamed_columns.extend(buffered_columns);
-                    streamed_columns
-                }
-                JoinType::Right => {
-                    buffered_columns.extend(streamed_columns);
-                    buffered_columns
-                }
-                JoinType::Semi | JoinType::Anti => {
-                    unreachable!()
-                }
-            };
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-            Ok(())
-        };
-
-        let mut buffered_batch_idx = 0;
-        let mut indices = vec![];
-        for &index in output_indices
-            .iter()
-            .filter(|index| index.streamed_idx.is_some())
-            .filter(|index| index.buffered_idx.is_some())
-        {
-            let buffered_idx = index.buffered_idx.unwrap();
-            if index.buffered_idx.unwrap().0 != buffered_batch_idx {
-                output(buffered_batch_idx, &indices)?;
-                buffered_batch_idx = buffered_idx.0;
-                indices.clear();
-            }
-            indices.push(index);
-        }
-        output(buffered_batch_idx, &indices)?;
+    // freeze when a dequeueing streamed batch
+    fn freeze_dequeuing_buffered(&mut self) -> ArrowResult<()> {
+        self.freeze_buffered_join_streamed(1)?;
+        self.freeze_buffered_join_null(1)?;
         Ok(())
     }
 
-    fn output_partial_streamed_joining_null(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        // streamed joining null
+    // join_type must be one of: `Left`/`Right`/`Full`/`Semi`/`Anti`
+    fn freeze_streamed_join_null(&mut self) -> ArrowResult<()> {
         let streamed_indices = UInt64Array::from_iter_values(
-            output_indices
+            self.streamed_batch
+                .null_joined
                 .iter()
-                .filter(|index| index.streamed_idx.is_some())
-                .filter(|index| index.buffered_idx.is_none())
-                .map(|index| index.streamed_idx.unwrap() as u64),
+                .map(|&index| index as u64),
         );
+        if streamed_indices.is_empty() {
+            return Ok(());
+        }
+        self.streamed_batch.null_joined.clear();
+
         let mut streamed_columns = self
             .streamed_batch
+            .batch
             .columns()
             .iter()
             .map(|column| take(column, &streamed_indices, None))
             .collect::<ArrowResult<Vec<_>>>()?;
 
-        let mut buffered_columns = self
-            .buffered_schema
-            .fields()
-            .iter()
-            .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
-            .collect::<Vec<_>>();
+        let columns = if matches!(self.join_type, JoinType::Semi | JoinType::Anti) {
+            streamed_columns
+        } else {
+            let mut buffered_columns = self
+                .buffered_schema
+                .fields()
+                .iter()
+                .map(|f| new_null_array(f.data_type(), streamed_indices.len()))
+                .collect::<Vec<_>>();
 
-        let columns = match self.join_type {
-            JoinType::Inner => {
-                unreachable!()
-            }
-            JoinType::Left | JoinType::Full => {
-                streamed_columns.extend(buffered_columns);
-                streamed_columns
-            }
-            JoinType::Right => {
+            if matches!(self.join_type, JoinType::Right) {
                 buffered_columns.extend(streamed_columns);
                 buffered_columns
+            } else {
+                streamed_columns.extend(buffered_columns);
+                streamed_columns
             }
-            JoinType::Anti | JoinType::Semi => streamed_columns,
         };
-
-        if !streamed_indices.is_empty() {
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-        }
+        self.output_record_batches
+            .push(RecordBatch::try_new(self.schema.clone(), columns)?);
         Ok(())
     }
 
-    fn output_partial_null_joining_buffered(
-        &mut self,
-        output_indices: &[OutputIndex],
-    ) -> ArrowResult<()> {
-        let mut output = |buffered_batch_idx: usize, indices: &[OutputIndex]| {
-            if indices.is_empty() {
-                return ArrowResult::Ok(());
-            }
-
-            // take buffered columns
+    // join_type must be `Full`
+    fn freeze_buffered_join_null(&mut self, batch_count: usize) -> ArrowResult<()> {
+        for buffered_batch in self.buffered_data.batches.range_mut(..batch_count) {
             let buffered_indices = UInt64Array::from_iter_values(
-                indices
-                    .iter()
-                    .map(|index| index.buffered_idx.unwrap().1 as u64),
+                buffered_batch.null_joined.iter().map(|&index| index as u64),
             );
-            let buffered_columns = self.buffered_data.batches[buffered_batch_idx]
+            if buffered_indices.is_empty() {
+                continue;
+            }
+            buffered_batch.null_joined.clear();
+
+            let buffered_columns = buffered_batch
                 .batch
                 .columns()
                 .iter()
                 .map(|column| take(column, &buffered_indices, None))
                 .collect::<ArrowResult<Vec<_>>>()?;
 
-            // create null streamed columns
             let mut streamed_columns = self
                 .streamed_schema
                 .fields()
                 .iter()
                 .map(|f| new_null_array(f.data_type(), buffered_indices.len()))
                 .collect::<Vec<_>>();
 
-            // combine columns and produce record batch
-            let columns = match self.join_type {
-                JoinType::Full => {
-                    streamed_columns.extend(buffered_columns);
-                    streamed_columns
-                }
-                JoinType::Inner
-                | JoinType::Left
-                | JoinType::Right
-                | JoinType::Semi
-                | JoinType::Anti => {
-                    unreachable!()
-                }
-            };
-            let record_batch = RecordBatch::try_new(self.schema.clone(), columns)?;
-            self.output_record_batches.push(record_batch);
-            Ok(())
-        };
+            streamed_columns.extend(buffered_columns);
+            let columns = streamed_columns;
 
-        let mut buffered_batch_idx = 0;
-        let mut indices = vec![];
-        for &index in output_indices
-            .iter()
-            .filter(|index| index.streamed_idx.is_none())
-            .filter(|index| index.buffered_idx.is_some())
-        {
-            let buffered_idx = index.buffered_idx.unwrap();
-            if buffered_idx.0 != buffered_batch_idx {
-                output(buffered_batch_idx, &indices)?;
-                buffered_batch_idx = buffered_idx.0;
-                indices.clear();
+            self.output_record_batches
+                .push(RecordBatch::try_new(self.schema.clone(), columns)?);
+        }
+        Ok(())
+    }
+
+    // join_type must be `Inner`/`Left`/`Right`/`Full`
+    fn freeze_buffered_join_streamed(&mut self, batch_count: usize) -> ArrowResult<()> {

Review Comment:
   Same here: please add an assertion on `join_type`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org