Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/28 21:03:01 UTC

[arrow-datafusion] branch master updated: feat: `ResourceExhausted` for memory limit in `AggregateStream` (#4405)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new dd3f72ad1 feat: `ResourceExhausted` for memory limit in `AggregateStream` (#4405)
dd3f72ad1 is described below

commit dd3f72ad13df3e3ab2efde73eba546012eaf10fd
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Mon Nov 28 21:02:55 2022 +0000

    feat: `ResourceExhausted` for memory limit in `AggregateStream` (#4405)
    
    Closes #3940.
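    
    As a rough sketch of the effect (not part of this patch): an aggregation
    run under a deliberately tiny memory limit should now surface a
    resources-exhausted error instead of allocating unchecked. The
    `MemoryManagerConfig::try_new_limit` API, the registered table `t`, and
    the exact error text below are assumptions based on this version of
    DataFusion:
    
        // (inside an async fn returning Result)
        use std::sync::Arc;
        use datafusion::execution::context::{SessionConfig, SessionContext};
        use datafusion::execution::memory_manager::MemoryManagerConfig;
        use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    
        // A 1-byte memory limit: any accumulator growth must fail.
        let runtime = Arc::new(RuntimeEnv::new(
            RuntimeConfig::new()
                .with_memory_manager(MemoryManagerConfig::try_new_limit(1, 1.0)?),
        )?);
        let ctx = SessionContext::with_config_rt(SessionConfig::new(), runtime);
    
        // Assumes a table `t` with a numeric column `a` is registered.
        // MEDIAN buffers all input values, so it allocates on every batch.
        let df = ctx.sql("SELECT MEDIAN(a) FROM t").await?;
        let err = df.collect().await.unwrap_err();
        assert!(err.to_string().contains("Resources exhausted"));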
---
 .../core/src/physical_plan/aggregates/mod.rs       |  25 +++-
 .../src/physical_plan/aggregates/no_grouping.rs    | 160 ++++++++++++++-------
 2 files changed, 126 insertions(+), 59 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 312a3263a..6d7c3c21b 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -295,6 +295,8 @@ impl AggregateExec {
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
+                context,
+                partition,
             )?))
         } else if self.row_aggregate_supported() {
             Ok(StreamType::GroupedHashAggregateStreamV2(
@@ -737,7 +739,7 @@ mod tests {
     use arrow::error::{ArrowError, Result as ArrowResult};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
-    use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count};
+    use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
     use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
     use futures::{FutureExt, Stream};
     use std::any::Any;
@@ -1131,12 +1133,20 @@ mod tests {
         );
         let task_ctx = session_ctx.task_ctx();
 
-        let groups = PhysicalGroupBy {
+        let groups_none = PhysicalGroupBy::default();
+        let groups_some = PhysicalGroupBy {
             expr: vec![(col("a", &input_schema)?, "a".to_string())],
             null_expr: vec![],
             groups: vec![vec![false]],
         };
 
+        // Median buffers every input value, so it allocates within the aggregator
+        let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
+            col("a", &input_schema)?,
+            "MEDIAN(a)".to_string(),
+            DataType::UInt32,
+        ))];
+
         // use slow-path in `hash.rs`
         let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
             vec![Arc::new(ApproxDistinct::new(
@@ -1152,10 +1162,14 @@ mod tests {
             DataType::Float64,
         ))];
 
-        for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] {
+        for (version, groups, aggregates) in [
+            (0, groups_none, aggregates_v0),
+            (1, groups_some.clone(), aggregates_v1),
+            (2, groups_some, aggregates_v2),
+        ] {
             let partial_aggregate = Arc::new(AggregateExec::try_new(
                 AggregateMode::Partial,
-                groups.clone(),
+                groups,
                 aggregates,
                 input.clone(),
                 input_schema.clone(),
@@ -1165,6 +1179,9 @@ mod tests {
 
             // ensure that we really got the version we wanted
             match version {
+                0 => {
+                    assert!(matches!(stream, StreamType::AggregateStream(_)));
+                }
                 1 => {
                     assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
                 }
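
The test above exercises version 0 via `Median`, whose accumulator buffers
every input value and therefore reports real growth through `size()`. As a
minimal, self-contained sketch of the size-delta accounting that the second
file below introduces (the types here are toys, not DataFusion's
`Accumulator` trait):

    use std::mem::size_of;

    /// Toy stand-in for a buffering accumulator such as `Median`.
    struct BufferingAccumulator {
        values: Vec<u32>,
    }

    impl BufferingAccumulator {
        /// Self-reported size in bytes, mirroring `Accumulator::size`.
        fn size(&self) -> usize {
            size_of::<Self>() + self.values.capacity() * size_of::<u32>()
        }

        fn update_batch(&mut self, batch: &[u32]) {
            self.values.extend_from_slice(batch);
        }
    }

    /// Returns the additional bytes allocated by one update, clamped at zero.
    fn update_and_measure(acc: &mut BufferingAccumulator, batch: &[u32]) -> usize {
        let size_pre = acc.size();
        acc.update_batch(batch);
        acc.size().saturating_sub(size_pre)
    }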
diff --git a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
index f687c982c..8c3556bb6 100644
--- a/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
+++ b/datafusion/core/src/physical_plan/aggregates/no_grouping.rs
@@ -17,6 +17,9 @@
 
 //! Aggregate without grouping columns
 
+use crate::execution::context::TaskContext;
+use crate::execution::memory_manager::proxy::MemoryConsumerProxy;
+use crate::execution::MemoryConsumerId;
 use crate::physical_plan::aggregates::{
     aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
     AggregateMode,
@@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::Result;
 use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
+use futures::stream::BoxStream;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use futures::{
-    ready,
-    stream::{Stream, StreamExt},
-};
+use futures::stream::{Stream, StreamExt};
 
 /// stream struct for aggregation without grouping columns
 pub(crate) struct AggregateStream {
+    stream: BoxStream<'static, ArrowResult<RecordBatch>>,
+    schema: SchemaRef,
+}
+
+/// Actual implementation of [`AggregateStream`].
+///
+/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
+/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
+/// [`futures::stream::unfold`]. The latter requires a state object, which is [`AggregateStreamInner`].
+struct AggregateStreamInner {
     schema: SchemaRef,
     mode: AggregateMode,
     input: SendableRecordBatchStream,
     baseline_metrics: BaselineMetrics,
     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
     accumulators: Vec<AccumulatorItem>,
+    memory_consumer: MemoryConsumerProxy,
     finished: bool,
 }
 
@@ -55,19 +67,87 @@ impl AggregateStream {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        context: Arc<TaskContext>,
+        partition: usize,
     ) -> datafusion_common::Result<Self> {
         let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
         let accumulators = create_accumulators(&aggr_expr)?;
-
-        Ok(Self {
-            schema,
+        let memory_consumer = MemoryConsumerProxy::new(
+            "AggregationState",
+            MemoryConsumerId::new(partition),
+            Arc::clone(&context.runtime_env().memory_manager),
+        );
+
+        let inner = AggregateStreamInner {
+            schema: Arc::clone(&schema),
             mode,
             input,
             baseline_metrics,
             aggregate_expressions,
             accumulators,
+            memory_consumer,
             finished: false,
-        })
+        };
+        let stream = futures::stream::unfold(inner, |mut this| async move {
+            if this.finished {
+                return None;
+            }
+
+            let elapsed_compute = this.baseline_metrics.elapsed_compute();
+
+            loop {
+                let result = match this.input.next().await {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = aggregate_batch(
+                            &this.mode,
+                            &batch,
+                            &mut this.accumulators,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        // allocate memory
+                        // This happens AFTER we have actually used the memory, but it simplifies the accounting and we
+                        // are OK with overshooting a bit. It also means we account for the whole record batch or not at all.
+                        let result = match result {
+                            Ok(allocated) => this.memory_consumer.alloc(allocated).await,
+                            Err(e) => Err(e),
+                        };
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        this.finished = true;
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = finalize_aggregation(&this.accumulators, &this.mode)
+                            .map_err(|e| ArrowError::ExternalError(Box::new(e)))
+                            .and_then(|columns| {
+                                RecordBatch::try_new(this.schema.clone(), columns)
+                            })
+                            .record_output(&this.baseline_metrics);
+
+                        timer.done();
+
+                        result
+                    }
+                };
+
+                this.finished = true;
+                return Some((result, this));
+            }
+        });
+
+        // Some consumers poll this stream even after it has returned `None`, so fuse it to make that safe.
+        let stream = stream.fuse();
+        let stream = Box::pin(stream);
+
+        Ok(Self { schema, stream })
     }
 }
 
@@ -79,49 +159,7 @@ impl Stream for AggregateStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }
-
-        let elapsed_compute = this.baseline_metrics.elapsed_compute();
-
-        loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = aggregate_batch(
-                        &this.mode,
-                        &batch,
-                        &mut this.accumulators,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                    }
-                }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = finalize_aggregation(&this.accumulators, &this.mode)
-                        .map_err(|e| ArrowError::ExternalError(Box::new(e)))
-                        .and_then(|columns| {
-                            RecordBatch::try_new(this.schema.clone(), columns)
-                        })
-                        .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
-        }
+        this.stream.poll_next_unpin(cx)
     }
 }
 
@@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream {
     }
 }
 
+/// Aggregate the given [`RecordBatch`] into the accumulators (no grouping columns).
+///
+/// If successful, this returns the additional number of bytes that were allocated during this process.
+///
 /// TODO: Make this a member function
 fn aggregate_batch(
     mode: &AggregateMode,
     batch: &RecordBatch,
     accumulators: &mut [AccumulatorItem],
     expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-) -> Result<()> {
+) -> Result<usize> {
+    let mut allocated = 0usize;
+
     // 1.1 iterate accumulators and respective expressions together
     // 1.2 evaluate expressions
     // 1.3 update / merge accumulators with the expressions' values
@@ -155,11 +199,17 @@ fn aggregate_batch(
                 .collect::<Result<Vec<_>>>()?;
 
             // 1.3
-            match mode {
+            let size_pre = accum.size();
+            let res = match mode {
                 AggregateMode::Partial => accum.update_batch(values),
                 AggregateMode::Final | AggregateMode::FinalPartitioned => {
                     accum.merge_batch(values)
                 }
-            }
-        })
+            };
+            let size_post = accum.size();
+            allocated += size_post.saturating_sub(size_pre);
+            res
+        })?;
+
+    Ok(allocated)
 }
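
For reference, a minimal, self-contained sketch of the wrapper pattern used in
`AggregateStream::new` above: `futures::stream::unfold` drives a state struct
so that the body may `.await` (here: a memory reservation), and the result is
boxed and fused. All names below are illustrative, not part of the patch:

    use futures::stream::{BoxStream, StreamExt};

    struct Inner {
        remaining: u32,
        finished: bool,
    }

    fn make_stream(inner: Inner) -> BoxStream<'static, u32> {
        let stream = futures::stream::unfold(inner, |mut this| async move {
            if this.finished || this.remaining == 0 {
                this.finished = true;
                return None;
            }
            this.remaining -= 1;
            // An `.await` (such as `memory_consumer.alloc(..).await`) is legal
            // here, which a hand-written poll-based `Stream` cannot express
            // without managing the pending future by hand.
            Some((this.remaining, this))
        });
        // Fuse so that polling after the final `None` stays well-behaved.
        Box::pin(stream.fuse())
    }

Compare with the code above: the real stream yields `ArrowResult<RecordBatch>`
items and awaits `MemoryConsumerProxy::alloc` after each batch is aggregated.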