You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/12 15:05:24 UTC

[GitHub] [arrow-datafusion] tustvold commented on a diff in pull request #2215: Remove tokio::spawn from HashAggregateExec (#2201)

tustvold commented on code in PR #2215:
URL: https://github.com/apache/arrow-datafusion/pull/2215#discussion_r848548276


##########
datafusion/core/src/physical_plan/hash_aggregate.rs:
##########
@@ -356,25 +354,130 @@ Example: average
 * Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns.
 * Finally, `get_value` returns an array with one entry computed from the state
 */
-pin_project! {
-    struct GroupedHashAggregateStream {
+struct GroupedHashAggregateStream {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    mode: AggregateMode,
+    accumulators: Accumulators,
+    aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+
+    aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+    group_expr: Vec<Arc<dyn PhysicalExpr>>,
+
+    baseline_metrics: BaselineMetrics,
+    random_state: RandomState,
+    finished: bool,
+}
+
+impl GroupedHashAggregateStream {
+    /// Create a new HashAggregateStream
+    pub fn new(
+        mode: AggregateMode,
         schema: SchemaRef,
-        #[pin]
-        output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
-        finished: bool,
-        drop_helper: AbortOnDropSingle<()>,
+        group_expr: Vec<Arc<dyn PhysicalExpr>>,
+        aggr_expr: Vec<Arc<dyn AggregateExpr>>,
+        input: SendableRecordBatchStream,
+        baseline_metrics: BaselineMetrics,
+    ) -> Result<Self> {
+        let timer = baseline_metrics.elapsed_compute().timer();
+
+        // The expressions to evaluate the batch, one vec of expressions per aggregation.
+        // Assume create_schema() always put group columns in front of aggr columns, we set
+        // col_idx_base to group expression count.
+        let aggregate_expressions =
+            aggregate_expressions(&aggr_expr, &mode, group_expr.len())?;
+
+        timer.done();
+
+        Ok(Self {
+            schema,
+            mode,
+            input,
+            aggr_expr,
+            group_expr,
+            baseline_metrics,
+            aggregate_expressions,
+            accumulators: Default::default(),
+            random_state: Default::default(),
+            finished: false,
+        })
+    }
+}
+
+impl Stream for GroupedHashAggregateStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let this = &mut *self;
+        if this.finished {
+            return Poll::Ready(None);
+        }
+
+        let elapsed_compute = this.baseline_metrics.elapsed_compute();
+
+        loop {
+            let result = match ready!(this.input.poll_next_unpin(cx)) {
+                Some(Ok(batch)) => {
+                    let timer = elapsed_compute.timer();
+                    let result = group_aggregate_batch(
+                        &this.mode,
+                        &this.random_state,
+                        &this.group_expr,
+                        &this.aggr_expr,
+                        batch,
+                        &mut this.accumulators,
+                        &this.aggregate_expressions,
+                    );
+
+                    timer.done();
+
+                    match result {
+                        Ok(_) => continue,
+                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                    }
+                }
+                Some(Err(e)) => Err(e),
+                None => {
+                    this.finished = true;
+                    let timer = this.baseline_metrics.elapsed_compute().timer();
+                    let result = create_batch_from_map(
+                        &this.mode,
+                        &this.accumulators,
+                        this.group_expr.len(),
+                        &this.schema,
+                    )
+                    .record_output(&this.baseline_metrics);
+
+                    timer.done();
+                    result
+                }
+            };
+
+            this.finished = true;
+            return Poll::Ready(Some(result));
+        }
+    }
+}
+
+impl RecordBatchStream for GroupedHashAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
     }
 }
 
+/// TODO: Make this a member function of [`GroupedHashAggregateStream`]

Review Comment:
   I left this as a TODO to keep this diff easier to review; I'll follow up with a subsequent PR that does just this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org