You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/28 14:25:57 UTC

[arrow-datafusion] branch master updated: feat: `ResourceExhausted` for memory limit in `GroupedHashAggregateStream` (#4371)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new be1d37671 feat: `ResourceExhausted` for memory limit in `GroupedHashAggregateStream` (#4371)
be1d37671 is described below

commit be1d376710a6b09151929f80107e2bbdd0c6538b
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Mon Nov 28 14:25:51 2022 +0000

    feat: `ResourceExhausted` for memory limit in `GroupedHashAggregateStream` (#4371)
    
    * feat: `ResourceExhausted` for memory limit in `GroupedHashAggregateStream`
    
    Closes #3940.
    
    * fix: `ScalarValue` size calculations
    
    * refactor: de-dup code
---
 datafusion/common/src/scalar.rs                    |  38 +++-
 .../core/src/physical_plan/aggregates/hash.rs      | 219 ++++++++++++++-------
 .../core/src/physical_plan/aggregates/mod.rs       | 176 +++++++++++------
 3 files changed, 303 insertions(+), 130 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 7f2ea5533..0c3e1d832 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -2297,7 +2297,7 @@ impl ScalarValue {
     /// Estimate size if bytes including `Self`. For values with internal containers such as `String`
     /// includes the allocated size (`capacity`) rather than the current length (`len`)
     pub fn size(&self) -> usize {
-        std::mem::size_of_val(&self)
+        std::mem::size_of_val(self)
             + match self {
                 ScalarValue::Null
                 | ScalarValue::Boolean(_)
@@ -2364,7 +2364,8 @@ impl ScalarValue {
     ///
     /// Includes the size of the [`Vec`] container itself.
     pub fn size_of_vec(vec: &Vec<Self>) -> usize {
-        (std::mem::size_of::<ScalarValue>() * vec.capacity())
+        std::mem::size_of_val(vec)
+            + (std::mem::size_of::<ScalarValue>() * vec.capacity())
             + vec
                 .iter()
                 .map(|sv| sv.size() - std::mem::size_of_val(sv))
@@ -2375,7 +2376,8 @@ impl ScalarValue {
     ///
     /// Includes the size of the [`HashSet`] container itself.
     pub fn size_of_hashset<S>(set: &HashSet<Self, S>) -> usize {
-        (std::mem::size_of::<ScalarValue>() * set.capacity())
+        std::mem::size_of_val(set)
+            + (std::mem::size_of::<ScalarValue>() * set.capacity())
             + set
                 .iter()
                 .map(|sv| sv.size() - std::mem::size_of_val(sv))
@@ -3281,6 +3283,36 @@ mod tests {
         assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
     }
 
+    #[test]
+    fn memory_size() {
+        let sv = ScalarValue::Binary(Some(Vec::with_capacity(10)));
+        assert_eq!(sv.size(), std::mem::size_of::<ScalarValue>() + 10,);
+        let sv_size = sv.size();
+
+        let mut v = Vec::with_capacity(10);
+        // do NOT clone `sv` here because this may shrink the vector capacity
+        v.push(sv);
+        assert_eq!(v.capacity(), 10);
+        assert_eq!(
+            ScalarValue::size_of_vec(&v),
+            std::mem::size_of::<Vec<ScalarValue>>()
+                + (9 * std::mem::size_of::<ScalarValue>())
+                + sv_size,
+        );
+
+        let mut s = HashSet::with_capacity(0);
+        // do NOT clone `sv` here because this may shrink the vector capacity
+        s.insert(v.pop().unwrap());
+        // hashsets may easily grow during insert, so capacity is dynamic
+        let s_capacity = s.capacity();
+        assert_eq!(
+            ScalarValue::size_of_hashset(&s),
+            std::mem::size_of::<HashSet<ScalarValue>>()
+                + ((s_capacity - 1) * std::mem::size_of::<ScalarValue>())
+                + sv_size,
+        );
+    }
+
     #[test]
     fn scalar_eq_array() {
         // Validate that eq_array has the same semantics as ScalarValue::eq
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 9487df61a..d3d5a337e 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -22,12 +22,16 @@ use std::task::{Context, Poll};
 use std::vec;
 
 use ahash::RandomState;
-use futures::{
-    ready,
-    stream::{Stream, StreamExt},
-};
+use datafusion_expr::Accumulator;
+use futures::stream::BoxStream;
+use futures::stream::{Stream, StreamExt};
 
 use crate::error::Result;
+use crate::execution::context::TaskContext;
+use crate::execution::memory_manager::proxy::{
+    MemoryConsumerProxy, RawTableAllocExt, VecAllocExt,
+};
+use crate::execution::MemoryConsumerId;
 use crate::physical_plan::aggregates::{
     evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy,
 };
@@ -74,6 +78,16 @@ Example: average
 * Finally, `get_value` returns an array with one entry computed from the state
 */
 pub(crate) struct GroupedHashAggregateStream {
+    stream: BoxStream<'static, ArrowResult<RecordBatch>>,
+    schema: SchemaRef,
+}
+
+/// Actual implementation of [`GroupedHashAggregateStream`].
+///
+/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
+/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
+/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamInner`].
+struct GroupedHashAggregateStreamInner {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
@@ -90,6 +104,7 @@ pub(crate) struct GroupedHashAggregateStream {
 
 impl GroupedHashAggregateStream {
     /// Create a new GroupedHashAggregateStream
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         mode: AggregateMode,
         schema: SchemaRef,
@@ -97,6 +112,8 @@ impl GroupedHashAggregateStream {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        context: Arc<TaskContext>,
+        partition: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
 
@@ -108,18 +125,92 @@ impl GroupedHashAggregateStream {
 
         timer.done();
 
-        Ok(Self {
-            schema,
+        let inner = GroupedHashAggregateStreamInner {
+            schema: Arc::clone(&schema),
             mode,
             input,
             aggr_expr,
             group_by,
             baseline_metrics,
             aggregate_expressions,
-            accumulators: Default::default(),
+            accumulators: Accumulators {
+                memory_consumer: MemoryConsumerProxy::new(
+                    "Accumulators",
+                    MemoryConsumerId::new(partition),
+                    Arc::clone(&context.runtime_env().memory_manager),
+                ),
+                map: RawTable::with_capacity(0),
+                group_states: Vec::with_capacity(0),
+            },
             random_state: Default::default(),
             finished: false,
-        })
+        };
+
+        let stream = futures::stream::unfold(inner, |mut this| async move {
+            if this.finished {
+                return None;
+            }
+
+            let elapsed_compute = this.baseline_metrics.elapsed_compute();
+
+            loop {
+                let result = match this.input.next().await {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &this.aggr_expr,
+                            batch,
+                            &mut this.accumulators,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        // allocate memory
+                        // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
+                        // overshooting a bit. Also this means we either store the whole record batch or not.
+                        let result = match result {
+                            Ok(allocated) => {
+                                this.accumulators.memory_consumer.alloc(allocated).await
+                            }
+                            Err(e) => Err(e),
+                        };
+
+                        match result {
+                            Ok(()) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
+                    }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        this.finished = true;
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.accumulators,
+                            this.group_by.expr.len(),
+                            &this.schema,
+                        )
+                        .record_output(&this.baseline_metrics);
+
+                        timer.done();
+                        result
+                    }
+                };
+
+                this.finished = true;
+                return Some((result, this));
+            }
+        });
+
+        // some consumers appear to poll this stream even after it has returned `None`, so let's fuse the stream.
+        let stream = stream.fuse();
+        let stream = Box::pin(stream);
+
+        Ok(Self { schema, stream })
     }
 }
 
@@ -131,53 +222,7 @@ impl Stream for GroupedHashAggregateStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }
-
-        let elapsed_compute = this.baseline_metrics.elapsed_compute();
-
-        loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &this.aggr_expr,
-                        batch,
-                        &mut this.accumulators,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
-                    }
-                }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.accumulators,
-                        this.group_by.expr.len(),
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
-        }
+        this.stream.poll_next_unpin(cx)
     }
 }
 
@@ -187,6 +232,10 @@ impl RecordBatchStream for GroupedHashAggregateStream {
     }
 }
 
+/// Perform group-by aggregation for the given [`RecordBatch`].
+///
+/// If successful, this returns the additional number of bytes that were allocated during this process.
+///
 /// TODO: Make this a member function of [`GroupedHashAggregateStream`]
 fn group_aggregate_batch(
     mode: &AggregateMode,
@@ -196,7 +245,7 @@ fn group_aggregate_batch(
     batch: RecordBatch,
     accumulators: &mut Accumulators,
     aggregate_expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-) -> Result<()> {
+) -> Result<usize> {
     // evaluate the grouping expressions
     let group_by_values = evaluate_group_by(group_by, &batch)?;
 
@@ -205,6 +254,9 @@ fn group_aggregate_batch(
     // of them anyways, it is more performant to do it while they are together.
     let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?;
 
+    // track memory allocations
+    let mut allocated = 0usize;
+
     for grouping_set_values in group_by_values {
         // 1.1 construct the key from the group values
         // 1.2 construct the mapping key if it does not exist
@@ -218,7 +270,9 @@ fn group_aggregate_batch(
         create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?;
 
         for (row, hash) in batch_hashes.into_iter().enumerate() {
-            let Accumulators { map, group_states } = accumulators;
+            let Accumulators {
+                map, group_states, ..
+            } = accumulators;
 
             let entry = map.get_mut(hash, |(_hash, group_idx)| {
                 // verify that a group that we are inserting with hash is
@@ -239,7 +293,9 @@ fn group_aggregate_batch(
                     if group_state.indices.is_empty() {
                         groups_with_rows.push(*group_idx);
                     };
-                    group_state.indices.push(row as u32); // remember this row
+                    group_state
+                        .indices
+                        .push_accounted(row as u32, &mut allocated); // remember this row
                 }
                 //  1.2 Need to create new entry
                 None => {
@@ -257,12 +313,32 @@ fn group_aggregate_batch(
                         accumulator_set,
                         indices: vec![row as u32], // 1.3
                     };
+                    // NOTE: do NOT include the `GroupState` struct size in here because this is captured by
+                    // `group_states` (see allocation down below)
+                    allocated += group_state
+                        .group_by_values
+                        .iter()
+                        .map(|sv| sv.size())
+                        .sum::<usize>()
+                        + (std::mem::size_of::<Box<dyn Accumulator>>()
+                            * group_state.accumulator_set.capacity())
+                        + group_state
+                            .accumulator_set
+                            .iter()
+                            .map(|accu| accu.size())
+                            .sum::<usize>()
+                        + (std::mem::size_of::<u32>() * group_state.indices.capacity());
+
                     let group_idx = group_states.len();
-                    group_states.push(group_state);
+                    group_states.push_accounted(group_state, &mut allocated);
                     groups_with_rows.push(group_idx);
 
                     // for hasher function, use precomputed hash value
-                    map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash);
+                    map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_idx)| *hash,
+                        &mut allocated,
+                    );
                 }
             };
         }
@@ -325,12 +401,18 @@ fn group_aggregate_batch(
                                 .collect::<Vec<ArrayRef>>(),
                         )
                     })
-                    .try_for_each(|(accumulator, values)| match mode {
-                        AggregateMode::Partial => accumulator.update_batch(&values),
-                        AggregateMode::FinalPartitioned | AggregateMode::Final => {
-                            // note: the aggregation here is over states, not values, thus the merge
-                            accumulator.merge_batch(&values)
-                        }
+                    .try_for_each(|(accumulator, values)| {
+                        let size_pre = accumulator.size();
+                        let res = match mode {
+                            AggregateMode::Partial => accumulator.update_batch(&values),
+                            AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                                // note: the aggregation here is over states, not values, thus the merge
+                                accumulator.merge_batch(&values)
+                            }
+                        };
+                        let size_post = accumulator.size();
+                        allocated += size_post.saturating_sub(size_pre);
+                        res
                     })
                     // 2.5
                     .and({
@@ -340,7 +422,7 @@ fn group_aggregate_batch(
             })?;
     }
 
-    Ok(())
+    Ok(allocated)
 }
 
 /// The state that is built for each output group.
@@ -358,8 +440,9 @@ struct GroupState {
 }
 
 /// The state of all the groups
-#[derive(Default)]
 struct Accumulators {
+    memory_consumer: MemoryConsumerProxy,
+
     /// Logically maps group values to an index in `group_states`
     ///
     /// Uses the raw API of hashbrown to avoid actually storing the
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 6ce58592d..312a3263a 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -150,6 +150,22 @@ impl PhysicalGroupBy {
     }
 }
 
+enum StreamType {
+    AggregateStream(AggregateStream),
+    GroupedHashAggregateStreamV2(GroupedHashAggregateStreamV2),
+    GroupedHashAggregateStream(GroupedHashAggregateStream),
+}
+
+impl From<StreamType> for SendableRecordBatchStream {
+    fn from(stream: StreamType) -> Self {
+        match stream {
+            StreamType::AggregateStream(stream) => Box::pin(stream),
+            StreamType::GroupedHashAggregateStreamV2(stream) => Box::pin(stream),
+            StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
+        }
+    }
+}
+
 /// Hash aggregate execution plan
 #[derive(Debug)]
 pub struct AggregateExec {
@@ -261,6 +277,54 @@ impl AggregateExec {
         row_supported(&group_schema, RowType::Compact)
             && accumulator_v2_supported(&self.aggr_expr)
     }
+
+    fn execute_typed(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<StreamType> {
+        let batch_size = context.session_config().batch_size();
+        let input = self.input.execute(partition, Arc::clone(&context))?;
+
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+
+        if self.group_by.expr.is_empty() {
+            Ok(StreamType::AggregateStream(AggregateStream::new(
+                self.mode,
+                self.schema.clone(),
+                self.aggr_expr.clone(),
+                input,
+                baseline_metrics,
+            )?))
+        } else if self.row_aggregate_supported() {
+            Ok(StreamType::GroupedHashAggregateStreamV2(
+                GroupedHashAggregateStreamV2::new(
+                    self.mode,
+                    self.schema.clone(),
+                    self.group_by.clone(),
+                    self.aggr_expr.clone(),
+                    input,
+                    baseline_metrics,
+                    batch_size,
+                    context,
+                    partition,
+                )?,
+            ))
+        } else {
+            Ok(StreamType::GroupedHashAggregateStream(
+                GroupedHashAggregateStream::new(
+                    self.mode,
+                    self.schema.clone(),
+                    self.group_by.clone(),
+                    self.aggr_expr.clone(),
+                    input,
+                    baseline_metrics,
+                    context,
+                    partition,
+                )?,
+            ))
+        }
+    }
 }
 
 impl ExecutionPlan for AggregateExec {
@@ -347,41 +411,8 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let batch_size = context.session_config().batch_size();
-        let input = self.input.execute(partition, Arc::clone(&context))?;
-
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-
-        if self.group_by.expr.is_empty() {
-            Ok(Box::pin(AggregateStream::new(
-                self.mode,
-                self.schema.clone(),
-                self.aggr_expr.clone(),
-                input,
-                baseline_metrics,
-            )?))
-        } else if self.row_aggregate_supported() {
-            Ok(Box::pin(GroupedHashAggregateStreamV2::new(
-                self.mode,
-                self.schema.clone(),
-                self.group_by.clone(),
-                self.aggr_expr.clone(),
-                input,
-                baseline_metrics,
-                batch_size,
-                context,
-                partition,
-            )?))
-        } else {
-            Ok(Box::pin(GroupedHashAggregateStream::new(
-                self.mode,
-                self.schema.clone(),
-                self.group_by.clone(),
-                self.aggr_expr.clone(),
-                input,
-                baseline_metrics,
-            )?))
-        }
+        self.execute_typed(partition, context)
+            .map(|stream| stream.into())
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -706,13 +737,14 @@ mod tests {
     use arrow::error::{ArrowError, Result as ArrowResult};
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
-    use datafusion_physical_expr::expressions::{lit, Count};
+    use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count};
     use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
     use futures::{FutureExt, Stream};
     use std::any::Any;
     use std::sync::Arc;
     use std::task::{Context, Poll};
 
+    use super::StreamType;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::{
         ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
@@ -1105,37 +1137,63 @@ mod tests {
             groups: vec![vec![false]],
         };
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+        // use slow-path in `hash.rs`
+        let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
+            vec![Arc::new(ApproxDistinct::new(
+                col("a", &input_schema)?,
+                "APPROX_DISTINCT(a)".to_string(),
+                DataType::UInt32,
+            ))];
+
+        // use fast-path in `row_hash.rs`.
+        let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
             col("b", &input_schema)?,
             "AVG(b)".to_string(),
             DataType::Float64,
         ))];
 
-        let partial_aggregate = Arc::new(AggregateExec::try_new(
-            AggregateMode::Partial,
-            groups,
-            aggregates,
-            input,
-            input_schema.clone(),
-        )?);
+        for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] {
+            let partial_aggregate = Arc::new(AggregateExec::try_new(
+                AggregateMode::Partial,
+                groups.clone(),
+                aggregates,
+                input.clone(),
+                input_schema.clone(),
+            )?);
+
+            let stream = partial_aggregate.execute_typed(0, task_ctx.clone())?;
+
+            // ensure that we really got the version we wanted
+            match version {
+                1 => {
+                    assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
+                }
+                2 => {
+                    assert!(matches!(
+                        stream,
+                        StreamType::GroupedHashAggregateStreamV2(_)
+                    ));
+                }
+                _ => panic!("Unknown version: {version}"),
+            }
 
-        let err = common::collect(partial_aggregate.execute(0, task_ctx.clone())?)
-            .await
-            .unwrap_err();
-
-        // error root cause traversal is a bit complicated, see #4172.
-        if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
-            if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                assert!(
-                    matches!(err, DataFusionError::ResourcesExhausted(_)),
-                    "Wrong inner error type: {}",
-                    err,
-                );
+            let stream: SendableRecordBatchStream = stream.into();
+            let err = common::collect(stream).await.unwrap_err();
+
+            // error root cause traversal is a bit complicated, see #4172.
+            if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
+                if let Some(err) = err.downcast_ref::<DataFusionError>() {
+                    assert!(
+                        matches!(err, DataFusionError::ResourcesExhausted(_)),
+                        "Wrong inner error type: {}",
+                        err,
+                    );
+                } else {
+                    panic!("Wrong arrow error type: {err}")
+                }
             } else {
-                panic!("Wrong arrow error type: {err}")
+                panic!("Wrong outer error type: {err}")
             }
-        } else {
-            panic!("Wrong outer error type: {err}")
         }
 
         Ok(())