You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/16 11:18:48 UTC

[GitHub] [arrow] alamb commented on a change in pull request #8473: ARROW-10320 [Rust] [DataFusion] Migrated from batch iterators to batch streams.

alamb commented on a change in pull request #8473:
URL: https://github.com/apache/arrow/pull/8473#discussion_r506288365



##########
File path: rust/datafusion/src/datasource/memory.rs
##########
@@ -135,6 +134,7 @@ mod tests {
     use super::*;
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
+    use futures::StreamExt;

Review comment:
       If we are worried about adding a new dependency on `futures`, we could perhaps upgrade the version of tokio instead and use the `StreamExt` from `tokio` rather than the one from `futures`: https://docs.rs/tokio/0.3.0/tokio/stream/trait.StreamExt.html
   
   
   
   

##########
File path: rust/datafusion/src/physical_plan/common.rs
##########
@@ -31,53 +32,58 @@ use array::{
 };
 use arrow::datatypes::{DataType, SchemaRef};
 use arrow::error::Result as ArrowResult;
-use arrow::record_batch::{RecordBatch, RecordBatchReader};
+use arrow::record_batch::RecordBatch;
 use arrow::{
     array::{self, ArrayRef},
     datatypes::Schema,
 };
+use futures::{Stream, TryStreamExt};
 
-/// Iterator over a vector of record batches
-pub struct RecordBatchIterator {
+/// Stream of record batches
+pub struct SizedRecordBatchStream {

Review comment:
       `BufferedRecordBatchStream` might be a more descriptive name for what this struct does.

##########
File path: rust/datafusion/src/physical_plan/filter.rs
##########
@@ -98,72 +102,76 @@ impl ExecutionPlan for FilterExec {
         }
     }
 
-    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchReader> {
-        Ok(Box::new(FilterExecIter {
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(FilterExecStream {
             schema: self.input.schema().clone(),
             predicate: self.predicate.clone(),
             input: self.input.execute(partition).await?,
         }))
     }
 }
 
-/// The FilterExec iterator wraps the input iterator and applies the predicate expression to
+/// The FilterExec streams wraps the input iterator and applies the predicate expression to
 /// determine which rows to include in its output batches
-struct FilterExecIter {
+struct FilterExecStream {
     /// Output schema, which is the same as the input schema for this operator
     schema: SchemaRef,
     /// The expression to filter on. This expression must evaluate to a boolean value.
     predicate: Arc<dyn PhysicalExpr>,
     /// The input partition to filter.
-    input: SendableRecordBatchReader,
+    input: SendableRecordBatchStream,
+}
+
+fn batch_filter(
+    batch: &RecordBatch,
+    predicate: &Arc<dyn PhysicalExpr>,
+) -> ArrowResult<RecordBatch> {
+    predicate
+        .evaluate(&batch)
+        .map_err(ExecutionError::into_arrow_external_error)
+        .and_then(|array| {
+            array
+                .as_any()
+                .downcast_ref::<BooleanArray>()
+                .ok_or(
+                    ExecutionError::InternalError(
+                        "Filter predicate evaluated to non-boolean value".to_string(),
+                    )
+                    .into_arrow_external_error(),
+                )
+                // apply predicate to each column
+                .and_then(|predicate| {
+                    batch
+                        .columns()
+                        .iter()
+                        .map(|column| filter(column.as_ref(), predicate))
+                        .collect::<ArrowResult<Vec<_>>>()
+                })
+        })
+        // build RecordBatch
+        .and_then(|columns| RecordBatch::try_new(batch.schema().clone(), columns))
 }
 
-impl Iterator for FilterExecIter {
+impl Stream for FilterExecStream {
     type Item = ArrowResult<RecordBatch>;
 
-    /// Get the next batch
-    fn next(&mut self) -> Option<ArrowResult<RecordBatch>> {
-        match self.input.next() {
-            Some(Ok(batch)) => {
-                // evaluate the filter predicate to get a boolean array indicating which rows
-                // to include in the output
-                Some(
-                    self.predicate
-                        .evaluate(&batch)
-                        .map_err(ExecutionError::into_arrow_external_error)
-                        .and_then(|array| {
-                            array
-                                .as_any()
-                                .downcast_ref::<BooleanArray>()
-                                .ok_or(
-                                    ExecutionError::InternalError(
-                                        "Filter predicate evaluated to non-boolean value"
-                                            .to_string(),
-                                    )
-                                    .into_arrow_external_error(),
-                                )
-                                // apply predicate to each column
-                                .and_then(|predicate| {
-                                    batch
-                                        .columns()
-                                        .iter()
-                                        .map(|column| filter(column.as_ref(), predicate))
-                                        .collect::<ArrowResult<Vec<_>>>()
-                                })
-                        })
-                        // build RecordBatch
-                        .and_then(|columns| {
-                            RecordBatch::try_new(batch.schema().clone(), columns)
-                        }),
-                )
-            }
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.input.poll_next_unpin(cx).map(|x| match x {
+            Some(Ok(batch)) => Some(batch_filter(&batch, &self.predicate)),

Review comment:
       This is nice -- it means the code doesn't need to read all the batches before it can filter each one.

##########
File path: rust/datafusion/tests/user_defined_plan.rs
##########
@@ -468,51 +469,69 @@ fn accumulate_batch(
         .expect("Column 1 is not revenue");
 
     for row in 0..num_rows {
-        add_row(top_values, customer_id.value(row), revenue.value(row), k);
+        add_row(
+            &mut top_values,
+            customer_id.value(row),
+            revenue.value(row),
+            k,
+        );
     }
-    Ok(())
+    Ok(top_values)
 }
 
-impl Iterator for TopKReader {
+impl Stream for TopKReader {
     type Item = std::result::Result<RecordBatch, ArrowError>;
 
-    /// Reads the next `RecordBatch`.
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         if self.done {
-            return None;
+            return Poll::Ready(None);
         }
-
-        // Hard coded implementation for sales / customer_id example
-        let mut top_values: BTreeMap<i64, String> = BTreeMap::new();
+        // this aggregates and thus returns a single RecordBatch.
+        self.done = true;
 
         // take this as immutable
-        let k = &self.k;
+        let k = self.k;
+        let schema = self.schema().clone();
 
-        self.input
+        let top_values = self
+            .input
             .as_mut()
-            .into_iter()
-            .map(|batch| accumulate_batch(&batch?, &mut top_values, k))
-            .collect::<Result<()>>()
-            .unwrap();
-
-        // make output by walking over the map backwards (so values are descending)
-        let (revenue, customer): (Vec<i64>, Vec<&String>) =
-            top_values.iter().rev().unzip();
-
-        let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
+            // Hard coded implementation for sales / customer_id example as BTree
+            .try_fold(
+                BTreeMap::<i64, String>::new(),
+                move |top_values, batch| async move {
+                    accumulate_batch(&batch, top_values, &k)
+                        .map_err(ExecutionError::into_arrow_external_error)
+                },
+            );
+
+        let top_values = top_values.map(|top_values| match top_values {
+            Ok(top_values) => {
+                // make output by walking over the map backwards (so values are descending)
+                let (revenue, customer): (Vec<i64>, Vec<&String>) =
+                    top_values.iter().rev().unzip();
+
+                let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
+                Ok(RecordBatch::try_new(
+                    schema,
+                    vec![
+                        Arc::new(StringArray::from(customer)),
+                        Arc::new(Int64Array::from(revenue)),
+                    ],
+                )?)
+            }
+            Err(e) => Err(e),
+        });
+        let mut top_values = Box::pin(top_values.into_stream());

Review comment:
       It is very cool to see how the user-defined code got transformed into a `Stream`.

##########
File path: rust/datafusion/src/physical_plan/merge.rs
##########
@@ -100,27 +103,29 @@ impl ExecutionPlan for MergeExec {
                 self.input.execute(0).await
             }
             _ => {
-                let tasks = (0..input_partitions)
-                    .map(|part_i| {
-                        let input = self.input.clone();
-                        tokio::spawn(async move {
-                            let it = input.execute(part_i).await?;
-                            common::collect(it)
-                        })
+                let tasks = (0..input_partitions).map(|part_i| {
+                    let input = self.input.clone();
+                    tokio::spawn(async move {
+                        let stream = input.execute(part_i).await?;
+                        common::collect(stream).await

Review comment:
       I am still concerned that these calls to `collect` effectively buffer all the input before producing any output. This will likely consume more memory than needed, and it also makes `Merge` a 'pipeline breaker' (nothing that relies on the output of the `Merge` can run until *all* of the merge inputs have been produced). 
   
   We could perhaps use `chain` here to fuse the streams together -- https://docs.rs/tokio/0.3.0/tokio/stream/trait.StreamExt.html#method.chain. Using `chain` would avoid the need to buffer the intermediate results (i.e. the `collect`), but it would also likely cause the input partitions to run one after the other rather than in parallel.
   
   Another thought I had would be to use something fancier like a `channel` that all the substreams write to. 
   
   But in any event that could be done in some future PR -- I think this code is better than master, and a step forward. 
   
   

##########
File path: rust/datafusion/src/physical_plan/parquet.rs
##########
@@ -197,24 +197,27 @@ fn read_file(
     Ok(())
 }
 
-struct ParquetIterator {
+struct ParquetStream {
     schema: SchemaRef,
     response_rx: Receiver<Option<ArrowResult<RecordBatch>>>,
 }
 
-impl Iterator for ParquetIterator {
+impl Stream for ParquetStream {
     type Item = ArrowResult<RecordBatch>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        _: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         match self.response_rx.recv() {

Review comment:
       This is getting closer to an `async` parquet reader 👍 

##########
File path: rust/datafusion/src/physical_plan/hash_aggregate.rs
##########
@@ -505,62 +513,88 @@ fn aggregate_batch(
 
             // 1.3
             match mode {
-                AggregateMode::Partial => accum.update_batch(values),
-                AggregateMode::Final => accum.merge_batch(values),
+                AggregateMode::Partial => {
+                    accum.update_batch(values)?;
+                }
+                AggregateMode::Final => {
+                    accum.merge_batch(values)?;
+                }
             }
+            Ok(accum)
         })
-        .collect::<Result<()>>()
+        .collect::<Result<Vec<_>>>()
 }
 
-impl Iterator for HashAggregateIterator {
+impl Stream for HashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         if self.finished {
-            return None;
+            return Poll::Ready(None);
         }
 
         // return single batch
         self.finished = true;
 
-        let mut accumulators = match create_accumulators(&self.aggr_expr) {
+        let accumulators = match create_accumulators(&self.aggr_expr) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
 
         let expressions = match aggregate_expressions(&self.aggr_expr, &self.mode) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
+        let expressions = Arc::new(expressions);
 
         let mode = self.mode;
         let schema = self.schema();
 
         // 1 for each batch, update / merge accumulators with the expressions' values
-        match self
+        // future is ready when all batches are computed
+        let future = self
             .input
             .as_mut()
-            .into_iter()
-            .map(|batch| {
-                aggregate_batch(&mode, &batch?, &mut accumulators, &expressions)
-                    .map_err(ExecutionError::into_arrow_external_error)
-            })
-            .collect::<ArrowResult<()>>()
-        {
-            Err(e) => return Some(Err(e)),
-            Ok(_) => {}
-        }
+            .try_fold(
+                // pass the expressions on every fold to handle closures' mutability
+                (accumulators, expressions),

Review comment:
       This is cool -- the aggregates are now accumulated incrementally as the streams are consumed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org