Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/18 20:06:17 UTC

[arrow-datafusion] branch master updated: Return `ResourceExhausted` errors when memory limit is exceeded in `GroupedHashAggregateStreamV2` (Row Hash) (#4202)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new f3a65c744 Return `ResourceExhausted` errors when memory limit is exceeded in `GroupedHashAggregateStreamV2` (Row Hash) (#4202)
f3a65c744 is described below

commit f3a65c74442fa42770418684a71a09ad9bcc348c
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Fri Nov 18 20:06:11 2022 +0000

    Return `ResourceExhausted` errors when memory limit is exceeded in `GroupedHashAggregateStreamV2` (Row Hash) (#4202)
    
    * refactor: remove needless async
    
    * feat: wire memory management into `GroupedHashAggregateStreamV2`
    
    Most of this change is refactoring that allows us to call the async memory
    subsystem while polling the stream. The actual memory accounting is
    straightforward, since usage only ever grows until the stream is dropped.
    
    Helps with #3940. (Not closing yet; V1 still needs the same treatment.)
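
    As a minimal sketch of the pattern (stand-in types and a toy memory check,
    not the actual DataFusion code; assumes the `futures` and `tokio` crates):
    `futures::stream::unfold` turns an async state machine into a `Stream`, so
    the memory reservation can be awaited between input batches, which a
    hand-written `Stream::poll_next` cannot do.

    ```rust
    use futures::stream::{self, Stream, StreamExt};

    /// Stand-in for the async memory-manager call; `try_grow` in the real
    /// code is `async`, which is the whole reason for this refactoring.
    async fn try_reserve(used: &mut usize, bytes: usize, limit: usize) -> Result<(), String> {
        if *used + bytes > limit {
            return Err(format!("Resources exhausted: cannot grow by {bytes} bytes"));
        }
        *used += bytes;
        Ok(())
    }

    struct State {
        batches: Vec<usize>, // stand-in for the sizes of incoming record batches
        used: usize,
    }

    fn accounted_stream(
        batches: Vec<usize>,
        limit: usize,
    ) -> impl Stream<Item = Result<usize, String>> {
        stream::unfold(State { batches, used: 0 }, move |mut st| async move {
            let bytes = st.batches.pop()?; // `None` ends the stream
            // awaiting is legal here, unlike inside `Stream::poll_next`
            let res = try_reserve(&mut st.used, bytes, limit).await.map(|()| bytes);
            Some((res, st))
        })
        // some consumers poll the stream again after `None`, so fuse it
        .fuse()
    }

    #[tokio::main]
    async fn main() {
        let mut s = Box::pin(accounted_stream(vec![64, 32, 16], 100));
        while let Some(item) = s.next().await {
            println!("{item:?}");
        }
    }
    ```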
    
    Performance Impact:
    -------------------
    
    ```text
    ❯ cargo bench -p datafusion --bench aggregate_query_sql -- --baseline issue3940a-pre
        Finished bench [optimized] target(s) in 0.08s
         Running benches/aggregate_query_sql.rs (target/release/deps/aggregate_query_sql-e9e315ab7a06a262)
    aggregate_query_no_group_by 15 12
                            time:   [654.77 µs 655.49 µs 656.29 µs]
                            change: [-1.6711% -1.2910% -0.8435%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low mild
      5 (5.00%) high mild
      3 (3.00%) high severe
    
    aggregate_query_no_group_by_min_max_f64
                            time:   [579.93 µs 580.59 µs 581.27 µs]
                            change: [-3.8985% -3.2219% -2.6198%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 9 outliers among 100 measurements (9.00%)
      1 (1.00%) low severe
      3 (3.00%) low mild
      1 (1.00%) high mild
      4 (4.00%) high severe
    
    aggregate_query_no_group_by_count_distinct_wide
                            time:   [2.4610 ms 2.4801 ms 2.4990 ms]
                            change: [-2.9300% -1.8414% -0.7493%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    
    Benchmarking aggregate_query_no_group_by_count_distinct_narrow: Warming up for 3.0000 s
    Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 8.4s, enable flat sampling, or reduce sample count to 50.
    aggregate_query_no_group_by_count_distinct_narrow
                            time:   [1.6578 ms 1.6661 ms 1.6743 ms]
                            change: [-4.5391% -3.5033% -2.5050%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 7 outliers among 100 measurements (7.00%)
      1 (1.00%) low severe
      2 (2.00%) low mild
      2 (2.00%) high mild
      2 (2.00%) high severe
    
    aggregate_query_group_by
                            time:   [2.1767 ms 2.2045 ms 2.2486 ms]
                            change: [-4.1048% -2.5858% -0.3237%] (p = 0.00 < 0.05)
                            Change within noise threshold.
    Found 1 outliers among 100 measurements (1.00%)
      1 (1.00%) high severe
    
    Benchmarking aggregate_query_group_by_with_filter: Warming up for 3.0000 s
    Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 5.5s, enable flat sampling, or reduce sample count to 60.
    aggregate_query_group_by_with_filter
                            time:   [1.0916 ms 1.0927 ms 1.0941 ms]
                            change: [-0.8524% -0.4230% -0.0724%] (p = 0.02 < 0.05)
                            Change within noise threshold.
    Found 9 outliers among 100 measurements (9.00%)
      2 (2.00%) low severe
      1 (1.00%) low mild
      4 (4.00%) high mild
      2 (2.00%) high severe
    
    aggregate_query_group_by_u64 15 12
                            time:   [2.2108 ms 2.2238 ms 2.2368 ms]
                            change: [-4.2142% -3.2743% -2.3523%] (p = 0.00 < 0.05)
                            Performance has improved.
    
    Benchmarking aggregate_query_group_by_with_filter_u64 15 12: Warming up for 3.0000 s
    Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 5.5s, enable flat sampling, or reduce sample count to 60.
    aggregate_query_group_by_with_filter_u64 15 12
                            time:   [1.0922 ms 1.0931 ms 1.0940 ms]
                            change: [-0.6872% -0.3192% +0.1193%] (p = 0.12 > 0.05)
                            No change in performance detected.
    Found 7 outliers among 100 measurements (7.00%)
      3 (3.00%) low mild
      4 (4.00%) high severe
    
    aggregate_query_group_by_u64_multiple_keys
                            time:   [14.714 ms 15.023 ms 15.344 ms]
                            change: [-5.8337% -2.7471% +0.2798%] (p = 0.09 > 0.05)
                            No change in performance detected.
    
    aggregate_query_approx_percentile_cont_on_u64
                            time:   [3.7776 ms 3.8049 ms 3.8329 ms]
                            change: [-4.4977% -3.4230% -2.3282%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 2 outliers among 100 measurements (2.00%)
      2 (2.00%) high mild
    
    aggregate_query_approx_percentile_cont_on_f32
                            time:   [3.1769 ms 3.1997 ms 3.2230 ms]
                            change: [-4.4664% -3.2597% -2.0955%] (p = 0.00 < 0.05)
                            Performance has improved.
    Found 1 outliers among 100 measurements (1.00%)
      1 (1.00%) high mild
    ```
    
    I think the mild improvements are either measurement noise or a side effect
    of the somewhat manual memory allocation pattern.
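
    For illustration, the accounting idea as a free-standing sketch (the commit
    implements it as the `VecAllocExt` extension trait; note the committed
    version hardcodes `size_of::<u32>()` for the bump size, while this sketch
    generalizes to `size_of::<T>()`): rather than consulting the memory manager
    on every push, a local counter accumulates bytes only when the `Vec`
    actually grows, and the total is reported to the manager once per batch.

    ```rust
    /// Push a value, tracking (approximately) how many extra bytes any
    /// growth reserved, mirroring `VecAllocExt::push_accounted` below.
    fn push_accounted<T>(v: &mut Vec<T>, x: T, accounting: &mut usize) {
        if v.capacity() == v.len() {
            // growth factor: 2, but at least 2 elements
            let bump_elements = (v.capacity() * 2).max(2);
            let bump_size = std::mem::size_of::<T>() * bump_elements;
            v.reserve(bump_elements);
            *accounting = accounting.checked_add(bump_size).expect("overflow");
        }
        v.push(x);
    }

    fn main() {
        let mut indices: Vec<u32> = Vec::new();
        let mut allocated = 0usize;
        for row in 0..1000u32 {
            push_accounted(&mut indices, row, &mut allocated);
        }
        // `allocated` approximates the bytes reserved while filling `indices`;
        // the real code hands this number to the memory manager afterwards
        println!("tracked ~{allocated} bytes for {} rows", indices.len());
    }
    ```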
    
    * refactor: simplify memory accounting
    
    * refactor: de-couple memory allocation
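
    For the end-to-end effect, a hedged sketch of the memory-limit
    configuration that the new `test_oom` test (see the diff below) exercises,
    driven through SQL instead; the table, its contents, and the query are
    made up for illustration:

    ```rust
    use std::sync::Arc;

    use datafusion::arrow::array::{ArrayRef, Float64Array, UInt32Array};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;
    use datafusion::execution::context::{SessionConfig, SessionContext};
    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

    #[tokio::main]
    async fn main() -> Result<()> {
        // 1 byte of memory, all of it available to "requesting" consumers:
        // any grouped aggregation should now exceed the limit
        let runtime = Arc::new(RuntimeEnv::new(
            RuntimeConfig::default().with_memory_limit(1, 1.0),
        )?);
        let ctx = SessionContext::with_config_rt(SessionConfig::default(), runtime);

        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::UInt32, false),
            Field::new("b", DataType::Float64, false),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(UInt32Array::from(vec![1u32, 1, 2])) as ArrayRef,
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )?;
        let _ = ctx.register_table(
            "t",
            Arc::new(MemTable::try_new(schema, vec![vec![batch]])?),
        )?;

        // with this commit the row-hash aggregation should fail cleanly with a
        // `ResourcesExhausted` error in the chain instead of over-allocating
        let err = ctx
            .sql("SELECT a, AVG(b) FROM t GROUP BY a")
            .await?
            .collect()
            .await
            .unwrap_err();
        println!("{err}");
        Ok(())
    }
    ```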
---
 datafusion/core/src/execution/memory_manager.rs    |   8 +-
 .../core/src/physical_plan/aggregates/mod.rs       |  66 +++-
 .../core/src/physical_plan/aggregates/row_hash.rs  | 355 ++++++++++++++++-----
 3 files changed, 343 insertions(+), 86 deletions(-)

diff --git a/datafusion/core/src/execution/memory_manager.rs b/datafusion/core/src/execution/memory_manager.rs
index 48d4ca3c3..e7148b066 100644
--- a/datafusion/core/src/execution/memory_manager.rs
+++ b/datafusion/core/src/execution/memory_manager.rs
@@ -178,10 +178,8 @@ pub trait MemoryConsumer: Send + Sync {
             self.id(),
         );
 
-        let can_grow_directly = self
-            .memory_manager()
-            .can_grow_directly(required, current)
-            .await;
+        let can_grow_directly =
+            self.memory_manager().can_grow_directly(required, current);
         if !can_grow_directly {
             debug!(
                 "Failed to grow memory of {} directly from consumer {}, spilling first ...",
@@ -334,7 +332,7 @@ impl MemoryManager {
     }
 
     /// Grow memory attempt from a consumer, return if we could grant that much to it
-    async fn can_grow_directly(&self, required: usize, current: usize) -> bool {
+    fn can_grow_directly(&self, required: usize, current: usize) -> bool {
         let num_rqt = self.requesters.lock().len();
         let mut rqt_current_used = self.requesters_total.lock();
         let mut rqt_max = self.max_mem_for_requesters();
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 43e75e352..6ce58592d 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -348,7 +348,7 @@ impl ExecutionPlan for AggregateExec {
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let batch_size = context.session_config().batch_size();
-        let input = self.input.execute(partition, context)?;
+        let input = self.input.execute(partition, Arc::clone(&context))?;
 
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
 
@@ -369,6 +369,8 @@ impl ExecutionPlan for AggregateExec {
                 input,
                 baseline_metrics,
                 batch_size,
+                context,
+                partition,
             )?))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
@@ -689,7 +691,8 @@ fn evaluate_group_by(
 
 #[cfg(test)]
 mod tests {
-    use crate::execution::context::TaskContext;
+    use crate::execution::context::{SessionConfig, TaskContext};
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::physical_plan::aggregates::{
         AggregateExec, AggregateMode, PhysicalGroupBy,
@@ -700,7 +703,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::Result as ArrowResult;
+    use arrow::error::{ArrowError, Result as ArrowResult};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, Count};
@@ -1081,6 +1084,63 @@ mod tests {
         check_grouping_sets(input).await
     }
 
+    #[tokio::test]
+    async fn test_oom() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: true });
+        let input_schema = input.schema();
+
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        let groups = PhysicalGroupBy {
+            expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            null_expr: vec![],
+            groups: vec![vec![false]],
+        };
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+            col("b", &input_schema)?,
+            "AVG(b)".to_string(),
+            DataType::Float64,
+        ))];
+
+        let partial_aggregate = Arc::new(AggregateExec::try_new(
+            AggregateMode::Partial,
+            groups,
+            aggregates,
+            input,
+            input_schema.clone(),
+        )?);
+
+        let err = common::collect(partial_aggregate.execute(0, task_ctx.clone())?)
+            .await
+            .unwrap_err();
+
+        // error root cause traversal is a bit complicated, see #4172.
+        if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
+            if let Some(err) = err.downcast_ref::<DataFusionError>() {
+                assert!(
+                    matches!(err, DataFusionError::ResourcesExhausted(_)),
+                    "Wrong inner error type: {}",
+                    err,
+                );
+            } else {
+                panic!("Wrong arrow error type: {err}")
+            }
+        } else {
+            panic!("Wrong outer error type: {err}")
+        }
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_drop_cancel_without_groups() -> Result<()> {
         let session_ctx = SessionContext::new();
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index b185ec1ec..740e8990e 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -22,12 +22,14 @@ use std::task::{Context, Poll};
 use std::vec;
 
 use ahash::RandomState;
-use futures::{
-    ready,
-    stream::{Stream, StreamExt},
-};
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use futures::stream::{Stream, StreamExt};
 
 use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::execution::memory_manager::ConsumerType;
+use crate::execution::{MemoryConsumer, MemoryConsumerId, MemoryManager};
 use crate::physical_plan::aggregates::{
     evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode,
     PhysicalGroupBy,
@@ -45,13 +47,13 @@ use arrow::{
     error::{ArrowError, Result as ArrowResult},
 };
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use datafusion_common::ScalarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
 use datafusion_row::accessor::RowAccessor;
 use datafusion_row::layout::RowLayout;
 use datafusion_row::reader::{read_row, RowReader};
 use datafusion_row::writer::{write_row, RowWriter};
 use datafusion_row::{MutableRecordBatch, RowType};
-use hashbrown::raw::RawTable;
+use hashbrown::raw::{Bucket, RawTable};
 
 /// Grouping aggregate with row-format aggregation states inside.
 ///
@@ -70,6 +72,16 @@ use hashbrown::raw::RawTable;
 /// [Compact]: datafusion_row::layout::RowType::Compact
 /// [WordAligned]: datafusion_row::layout::RowType::WordAligned
 pub(crate) struct GroupedHashAggregateStreamV2 {
+    stream: BoxStream<'static, ArrowResult<RecordBatch>>,
+    schema: SchemaRef,
+}
+
+/// Actual implementation of [`GroupedHashAggregateStreamV2`].
+///
+/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
+/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
+/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamV2Inner`].
+struct GroupedHashAggregateStreamV2Inner {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
@@ -102,6 +114,7 @@ fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef>
 
 impl GroupedHashAggregateStreamV2 {
     /// Create a new GroupedRowHashAggregateStream
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         mode: AggregateMode,
         schema: SchemaRef,
@@ -110,6 +123,8 @@ impl GroupedHashAggregateStreamV2 {
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
         batch_size: usize,
+        context: Arc<TaskContext>,
+        partition: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
 
@@ -125,10 +140,24 @@ impl GroupedHashAggregateStreamV2 {
         let aggr_schema = aggr_state_schema(&aggr_expr)?;
 
         let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned));
+
+        let aggr_state = AggregationState {
+            memory_consumer: AggregationStateMemoryConsumer {
+                id: MemoryConsumerId::new(partition),
+                memory_manager: Arc::clone(&context.runtime_env().memory_manager),
+                used: 0,
+            },
+            map: RawTable::with_capacity(0),
+            group_states: Vec::with_capacity(0),
+        };
+        context
+            .runtime_env()
+            .register_requester(aggr_state.memory_consumer.id());
+
         timer.done();
 
-        Ok(Self {
-            schema,
+        let inner = GroupedHashAggregateStreamV2Inner {
+            schema: Arc::clone(&schema),
             mode,
             input,
             group_by,
@@ -138,11 +167,87 @@ impl GroupedHashAggregateStreamV2 {
             aggr_layout,
             baseline_metrics,
             aggregate_expressions,
-            aggr_state: Default::default(),
+            aggr_state,
             random_state: Default::default(),
             batch_size,
             row_group_skip_position: 0,
-        })
+        };
+
+        let stream = futures::stream::unfold(inner, |mut this| async move {
+            let elapsed_compute = this.baseline_metrics.elapsed_compute();
+
+            loop {
+                let result: ArrowResult<Option<RecordBatch>> =
+                    match this.input.next().await {
+                        Some(Ok(batch)) => {
+                            let timer = elapsed_compute.timer();
+                            let result = group_aggregate_batch(
+                                &this.mode,
+                                &this.random_state,
+                                &this.group_by,
+                                &mut this.accumulators,
+                                &this.group_schema,
+                                this.aggr_layout.clone(),
+                                batch,
+                                &mut this.aggr_state,
+                                &this.aggregate_expressions,
+                            );
+
+                            timer.done();
+
+                            // allocate memory
+                            // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
+                            // overshooting a bit. Also this means we either store the whole record batch or not.
+                            let result = match result {
+                                Ok(allocated) => {
+                                    this.aggr_state.memory_consumer.alloc(allocated).await
+                                }
+                                Err(e) => Err(e),
+                            };
+
+                            match result {
+                                Ok(()) => continue,
+                                Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                            }
+                        }
+                        Some(Err(e)) => Err(e),
+                        None => {
+                            let timer = this.baseline_metrics.elapsed_compute().timer();
+                            let result = create_batch_from_map(
+                                &this.mode,
+                                &this.group_schema,
+                                &this.aggr_schema,
+                                this.batch_size,
+                                this.row_group_skip_position,
+                                &mut this.aggr_state,
+                                &mut this.accumulators,
+                                &this.schema,
+                            );
+
+                            timer.done();
+                            result
+                        }
+                    };
+
+                this.row_group_skip_position += this.batch_size;
+                match result {
+                    Ok(Some(result)) => {
+                        return Some((
+                            Ok(result.record_output(&this.baseline_metrics)),
+                            this,
+                        ));
+                    }
+                    Ok(None) => return None,
+                    Err(error) => return Some((Err(error), this)),
+                }
+            }
+        });
+
+        // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
+        let stream = stream.fuse();
+        let stream = Box::pin(stream);
+
+        Ok(Self { schema, stream })
     }
 }
 
@@ -154,63 +259,7 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-
-        let elapsed_compute = this.baseline_metrics.elapsed_compute();
-
-        loop {
-            let result: ArrowResult<Option<RecordBatch>> =
-                match ready!(this.input.poll_next_unpin(cx)) {
-                    Some(Ok(batch)) => {
-                        let timer = elapsed_compute.timer();
-                        let result = group_aggregate_batch(
-                            &this.mode,
-                            &this.random_state,
-                            &this.group_by,
-                            &mut this.accumulators,
-                            &this.group_schema,
-                            this.aggr_layout.clone(),
-                            batch,
-                            &mut this.aggr_state,
-                            &this.aggregate_expressions,
-                        );
-
-                        timer.done();
-
-                        match result {
-                            Ok(_) => continue,
-                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                        }
-                    }
-                    Some(Err(e)) => Err(e),
-                    None => {
-                        let timer = this.baseline_metrics.elapsed_compute().timer();
-                        let result = create_batch_from_map(
-                            &this.mode,
-                            &this.group_schema,
-                            &this.aggr_schema,
-                            this.batch_size,
-                            this.row_group_skip_position,
-                            &mut this.aggr_state,
-                            &mut this.accumulators,
-                            &this.schema,
-                        );
-
-                        timer.done();
-                        result
-                    }
-                };
-
-            this.row_group_skip_position += this.batch_size;
-            match result {
-                Ok(Some(result)) => {
-                    return Poll::Ready(Some(Ok(
-                        result.record_output(&this.baseline_metrics)
-                    )))
-                }
-                Ok(None) => return Poll::Ready(None),
-                Err(error) => return Poll::Ready(Some(Err(error))),
-            }
-        }
+        this.stream.poll_next_unpin(cx)
     }
 }
 
@@ -220,6 +269,10 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 {
     }
 }
 
+/// Perform group-by aggregation for the given [`RecordBatch`].
+///
+/// If successful, this returns the additional number of bytes that were allocated during this process.
+///
 /// TODO: Make this a member function of [`GroupedHashAggregateStreamV2`]
 #[allow(clippy::too_many_arguments)]
 fn group_aggregate_batch(
@@ -232,10 +285,15 @@ fn group_aggregate_batch(
     batch: RecordBatch,
     aggr_state: &mut AggregationState,
     aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-) -> Result<()> {
+) -> Result<usize> {
     // evaluate the grouping expressions
     let grouping_by_values = evaluate_group_by(grouping_set, &batch)?;
 
+    let AggregationState {
+        map, group_states, ..
+    } = aggr_state;
+    let mut allocated = 0usize;
+
     for group_values in grouping_by_values {
         let group_rows: Vec<Vec<u8>> = create_group_rows(group_values, group_schema);
 
@@ -256,8 +314,6 @@ fn group_aggregate_batch(
         create_row_hashes(&group_rows, random_state, &mut batch_hashes)?;
 
         for (row, hash) in batch_hashes.into_iter().enumerate() {
-            let AggregationState { map, group_states } = aggr_state;
-
             let entry = map.get_mut(hash, |(_hash, group_idx)| {
                 // verify that a group that we are inserting with hash is
                 // actually the same key value as the group in
@@ -270,11 +326,15 @@ fn group_aggregate_batch(
                 // Existing entry for this group value
                 Some((_hash, group_idx)) => {
                     let group_state = &mut group_states[*group_idx];
+
                     // 1.3
                     if group_state.indices.is_empty() {
                         groups_with_rows.push(*group_idx);
                     };
-                    group_state.indices.push(row as u32); // remember this row
+
+                    group_state
+                        .indices
+                        .push_accounted(row as u32, &mut allocated); // remember this row
                 }
                 //  1.2 Need to create new entry
                 None => {
@@ -285,11 +345,25 @@ fn group_aggregate_batch(
                         indices: vec![row as u32], // 1.3
                     };
                     let group_idx = group_states.len();
-                    group_states.push(group_state);
-                    groups_with_rows.push(group_idx);
+
+                    // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by
+                    // `group_states` (see allocation down below)
+                    allocated += (std::mem::size_of::<u8>()
+                        * group_state.group_by_values.capacity())
+                        + (std::mem::size_of::<u8>()
+                            * group_state.aggregation_buffer.capacity())
+                        + (std::mem::size_of::<u32>() * group_state.indices.capacity());
 
                     // for hasher function, use precomputed hash value
-                    map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+                    map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_index)| *hash,
+                        &mut allocated,
+                    );
+
+                    group_states.push_accounted(group_state, &mut allocated);
+
+                    groups_with_rows.push(group_idx);
                 }
             };
         }
@@ -299,7 +373,7 @@ fn group_aggregate_batch(
         let mut offsets = vec![0];
         let mut offset_so_far = 0;
         for group_idx in groups_with_rows.iter() {
-            let indices = &aggr_state.group_states[*group_idx].indices;
+            let indices = &group_states[*group_idx].indices;
             batch_indices.append_slice(indices);
             offset_so_far += indices.len();
             offsets.push(offset_so_far);
@@ -334,7 +408,7 @@ fn group_aggregate_batch(
             .iter()
             .zip(offsets.windows(2))
             .try_for_each(|(group_idx, offsets)| {
-                let group_state = &mut aggr_state.group_states[*group_idx];
+                let group_state = &mut group_states[*group_idx];
                 // 2.2
                 accumulators
                     .iter_mut()
@@ -374,7 +448,7 @@ fn group_aggregate_batch(
             })?;
     }
 
-    Ok(())
+    Ok(allocated)
 }
 
 /// The state that is built for each output group.
@@ -392,8 +466,9 @@ struct RowGroupState {
 }
 
 /// The state of all the groups
-#[derive(Default)]
 struct AggregationState {
+    memory_consumer: AggregationStateMemoryConsumer,
+
     /// Logically maps group values to an index in `group_states`
     ///
     /// Uses the raw API of hashbrown to avoid actually storing the
@@ -418,6 +493,130 @@ impl std::fmt::Debug for AggregationState {
     }
 }
 
+/// Accounting data structure for memory usage.
+struct AggregationStateMemoryConsumer {
+    /// Consumer ID.
+    id: MemoryConsumerId,
+
+    /// Linked memory manager.
+    memory_manager: Arc<MemoryManager>,
+
+    /// Currently used size in bytes.
+    used: usize,
+}
+
+#[async_trait]
+impl MemoryConsumer for AggregationStateMemoryConsumer {
+    fn name(&self) -> String {
+        "AggregationState".to_owned()
+    }
+
+    fn id(&self) -> &crate::execution::MemoryConsumerId {
+        &self.id
+    }
+
+    fn memory_manager(&self) -> Arc<MemoryManager> {
+        Arc::clone(&self.memory_manager)
+    }
+
+    fn type_(&self) -> &ConsumerType {
+        &ConsumerType::Tracking
+    }
+
+    async fn spill(&self) -> Result<usize> {
+        Err(DataFusionError::ResourcesExhausted(
+            "Cannot spill AggregationState".to_owned(),
+        ))
+    }
+
+    fn mem_used(&self) -> usize {
+        self.used
+    }
+}
+
+impl AggregationStateMemoryConsumer {
+    async fn alloc(&mut self, bytes: usize) -> Result<()> {
+        self.try_grow(bytes).await?;
+        self.used = self.used.checked_add(bytes).expect("overflow");
+        Ok(())
+    }
+}
+
+impl Drop for AggregationStateMemoryConsumer {
+    fn drop(&mut self) {
+        self.memory_manager
+            .drop_consumer(self.id(), self.mem_used());
+    }
+}
+
+trait VecAllocExt {
+    type T;
+
+    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize);
+}
+
+impl<T> VecAllocExt for Vec<T> {
+    type T = T;
+
+    fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) {
+        if self.capacity() == self.len() {
+            // allocate more
+
+            // growth factor: 2, but at least 2 elements
+            let bump_elements = (self.capacity() * 2).max(2);
+            let bump_size = std::mem::size_of::<u32>() * bump_elements;
+            self.reserve(bump_elements);
+            *accounting = (*accounting).checked_add(bump_size).expect("overflow");
+        }
+
+        self.push(x);
+    }
+}
+
+trait RawTableAllocExt {
+    type T;
+
+    fn insert_accounted(
+        &mut self,
+        x: Self::T,
+        hasher: impl Fn(&Self::T) -> u64,
+        accounting: &mut usize,
+    ) -> Bucket<Self::T>;
+}
+
+impl<T> RawTableAllocExt for RawTable<T> {
+    type T = T;
+
+    fn insert_accounted(
+        &mut self,
+        x: Self::T,
+        hasher: impl Fn(&Self::T) -> u64,
+        accounting: &mut usize,
+    ) -> Bucket<Self::T> {
+        let hash = hasher(&x);
+
+        match self.try_insert_no_grow(hash, x) {
+            Ok(bucket) => bucket,
+            Err(x) => {
+                // need to request more memory
+
+                let bump_elements = (self.capacity() * 2).max(16);
+                let bump_size = bump_elements * std::mem::size_of::<T>();
+                *accounting = (*accounting).checked_add(bump_size).expect("overflow");
+
+                self.reserve(bump_elements, hasher);
+
+                // still need to insert the element since first try failed
+                // Note: cannot use `.expect` here because `T` may not implement `Debug`
+                match self.try_insert_no_grow(hash, x) {
+                    Ok(bucket) => bucket,
+                    Err(_) => panic!("just grew the container"),
+                }
+            }
+        }
+    }
+}
+
 /// Create grouping rows
 fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
     let mut writer = RowWriter::new(schema, RowType::Compact);