You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/16 12:24:33 UTC

[GitHub] [arrow] vertexclique commented on a change in pull request #8473: ARROW-10320 [Rust] [DataFusion] Migrated from batch iterators to batch streams.

vertexclique commented on a change in pull request #8473:
URL: https://github.com/apache/arrow/pull/8473#discussion_r506357109



##########
File path: rust/datafusion/src/physical_plan/hash_aggregate.rs
##########
@@ -331,72 +337,74 @@ impl GroupedHashAggregateIterator {
 }
 
 type AccumulatorSet = Vec<Box<dyn Accumulator>>;
+type Accumulators = FnvHashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>)>;
 
-impl Iterator for GroupedHashAggregateIterator {
+impl Stream for GroupedHashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         if self.finished {
-            return None;
+            return Poll::Ready(None);
         }
 
         // return single batch
         self.finished = true;
 
-        let mode = &self.mode;
-        let group_expr = &self.group_expr;
-        let aggr_expr = &self.aggr_expr;
+        let mode = self.mode.clone();
+        let group_expr = self.group_expr.clone();
+        let aggr_expr = self.aggr_expr.clone();
+        let schema = self.schema.clone();
 
         // the expressions to evaluate the batch, one vec of expressions per aggregation
         let aggregate_expressions = match aggregate_expressions(&aggr_expr, &mode) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
 
         // mapping key -> (set of accumulators, indices of the key in the batch)
         // * the indexes are updated at each row
         // * the accumulators are updated at the end of each batch
         // * the indexes are `clear`ed at the end of each batch
-        let mut accumulators: FnvHashMap<
-            Vec<GroupByScalar>,
-            (AccumulatorSet, Box<Vec<u32>>),
-        > = FnvHashMap::default();
+        //let mut accumulators: Accumulators = FnvHashMap::default();
 
         // iterate over all input batches and update the accumulators
-        match self
-            .input
-            .as_mut()
-            .into_iter()
-            .map(|batch| {
+        let future = self.input.as_mut().try_fold(
+            Accumulators::default(),
+            |accumulators, batch| async {
                 group_aggregate_batch(
                     &mode,
                     &group_expr,
                     &aggr_expr,
-                    &batch?,
-                    &mut accumulators,
+                    batch,
+                    accumulators,
                     &aggregate_expressions,
                 )
                 .map_err(ExecutionError::into_arrow_external_error)
-            })
-            .collect::<ArrowResult<()>>()
-        {
-            Err(e) => return Some(Err(e)),
-            Ok(_) => {}
-        }
+            },
+        );
 
-        Some(
-            create_batch_from_map(
-                &self.mode,
-                &accumulators,
-                self.group_expr.len(),
-                &self.schema,
-            )
-            .map_err(ExecutionError::into_arrow_external_error),
-        )
+        let future = future.map(|accumulators| match accumulators {
+            Ok(accumulators) => {
+                create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema)
+            }
+            Err(e) => Err(e),
+        });

Review comment:
       Can you use `map` instead of pattern matching here? The explicit error propagation looks redundant.

##########
File path: rust/datafusion/src/physical_plan/hash_aggregate.rs
##########
@@ -505,62 +513,88 @@ fn aggregate_batch(
 
             // 1.3
             match mode {
-                AggregateMode::Partial => accum.update_batch(values),
-                AggregateMode::Final => accum.merge_batch(values),
+                AggregateMode::Partial => {
+                    accum.update_batch(values)?;
+                }
+                AggregateMode::Final => {
+                    accum.merge_batch(values)?;
+                }
             }
+            Ok(accum)
         })
-        .collect::<Result<()>>()
+        .collect::<Result<Vec<_>>>()
 }
 
-impl Iterator for HashAggregateIterator {
+impl Stream for HashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         if self.finished {
-            return None;
+            return Poll::Ready(None);
         }
 
         // return single batch
         self.finished = true;
 
-        let mut accumulators = match create_accumulators(&self.aggr_expr) {
+        let accumulators = match create_accumulators(&self.aggr_expr) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
 
         let expressions = match aggregate_expressions(&self.aggr_expr, &self.mode) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
+        let expressions = Arc::new(expressions);
 
         let mode = self.mode;
         let schema = self.schema();
 
         // 1 for each batch, update / merge accumulators with the expressions' values
-        match self
+        // future is ready when all batches are computed
+        let future = self
             .input
             .as_mut()
-            .into_iter()
-            .map(|batch| {
-                aggregate_batch(&mode, &batch?, &mut accumulators, &expressions)
-                    .map_err(ExecutionError::into_arrow_external_error)
-            })
-            .collect::<ArrowResult<()>>()
-        {
-            Err(e) => return Some(Err(e)),
-            Ok(_) => {}
-        }
+            .try_fold(
+                // pass the expressions on every fold to handle closures' mutability
+                (accumulators, expressions),
+                |(acc, expr), batch| async move {
+                    aggregate_batch(&mode, &batch, acc, &expr)
+                        .map_err(ExecutionError::into_arrow_external_error)
+                        .map(|agg| (agg, expr))
+                },
+            )
+            // pick the accumulators (disregard the expressions)
+            .map(|e| e.map(|e| e.0));
+
+        let future = future.map(|b| {
+            match b {
+                Err(e) => return Err(e),
+                Ok(acc) => {
+                    // 2 convert values to a record batch

Review comment:
       I presume the `2` here means `to`, right? :)

##########
File path: rust/datafusion/src/physical_plan/hash_aggregate.rs
##########
@@ -505,62 +513,88 @@ fn aggregate_batch(
 
             // 1.3
             match mode {
-                AggregateMode::Partial => accum.update_batch(values),
-                AggregateMode::Final => accum.merge_batch(values),
+                AggregateMode::Partial => {
+                    accum.update_batch(values)?;
+                }
+                AggregateMode::Final => {
+                    accum.merge_batch(values)?;
+                }
             }
+            Ok(accum)
         })
-        .collect::<Result<()>>()
+        .collect::<Result<Vec<_>>>()
 }
 
-impl Iterator for HashAggregateIterator {
+impl Stream for HashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
         if self.finished {
-            return None;
+            return Poll::Ready(None);
         }
 
         // return single batch
         self.finished = true;
 
-        let mut accumulators = match create_accumulators(&self.aggr_expr) {
+        let accumulators = match create_accumulators(&self.aggr_expr) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
 
         let expressions = match aggregate_expressions(&self.aggr_expr, &self.mode) {
             Ok(e) => e,
-            Err(e) => return Some(Err(ExecutionError::into_arrow_external_error(e))),
+            Err(e) => {
+                return Poll::Ready(Some(Err(ExecutionError::into_arrow_external_error(
+                    e,
+                ))))
+            }
         };
+        let expressions = Arc::new(expressions);
 
         let mode = self.mode;
         let schema = self.schema();
 
         // 1 for each batch, update / merge accumulators with the expressions' values
-        match self
+        // future is ready when all batches are computed
+        let future = self
             .input
             .as_mut()
-            .into_iter()
-            .map(|batch| {
-                aggregate_batch(&mode, &batch?, &mut accumulators, &expressions)
-                    .map_err(ExecutionError::into_arrow_external_error)
-            })
-            .collect::<ArrowResult<()>>()
-        {
-            Err(e) => return Some(Err(e)),
-            Ok(_) => {}
-        }
+            .try_fold(
+                // pass the expressions on every fold to handle closures' mutability
+                (accumulators, expressions),
+                |(acc, expr), batch| async move {
+                    aggregate_batch(&mode, &batch, acc, &expr)
+                        .map_err(ExecutionError::into_arrow_external_error)
+                        .map(|agg| (agg, expr))
+                },
+            )
+            // pick the accumulators (disregard the expressions)
+            .map(|e| e.map(|e| e.0));
+
+        let future = future.map(|b| {
+            match b {
+                Err(e) => return Err(e),
+                Ok(acc) => {
+                    // 2 convert values to a record batch
+                    finalize_aggregation(&acc, &mode)
+                        .map_err(ExecutionError::into_arrow_external_error)
+                        .and_then(|columns| RecordBatch::try_new(schema.clone(), columns))
+                }
+            }
+        });

Review comment:
       The same `map` simplification probably applies here.

##########
File path: rust/datafusion/src/datasource/memory.rs
##########
@@ -135,6 +134,7 @@ mod tests {
     use super::*;
     use arrow::array::Int32Array;
     use arrow::datatypes::{DataType, Field, Schema};
+    use futures::StreamExt;

Review comment:
       Please do not do that. Futures interfaces are not compatible with other APIs in the ecosystem. This is a good implementation which is adaptable to tokio, async-std, bastion and others.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org