You are viewing a plain text version of this content. The canonical version is linked from the original page; that hyperlink is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/07 09:19:32 UTC

[arrow-datafusion] branch master updated: fix: account for memory in `RepartitionExec` (#4820)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 83c102698 fix: account for memory in `RepartitionExec` (#4820)
83c102698 is described below

commit 83c102698945e0984f8fa53e75b04478e49e5242
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Sat Jan 7 10:19:25 2023 +0100

    fix: account for memory in `RepartitionExec` (#4820)
    
    * refactor: explicit loop instead of (tail) recursion
    
    * test: simplify
    
    * fix: account for memory in `RepartitionExec`
    
    Fixes #4816.
    
    * fix: sorting memory limit test
---
 .../core/src/physical_plan/aggregates/mod.rs       |  19 +--
 datafusion/core/src/physical_plan/repartition.rs   | 127 +++++++++++++++++----
 datafusion/core/tests/memory_limit.rs              |   6 +-
 3 files changed, 116 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 07f3563bb..8044f4c15 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -746,7 +746,7 @@ mod tests {
     use crate::{assert_batches_sorted_eq, physical_plan::common};
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-    use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow::error::Result as ArrowResult;
     use arrow::record_batch::RecordBatch;
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1207,18 +1207,11 @@ mod tests {
             let err = common::collect(stream).await.unwrap_err();
 
             // error root cause traversal is a bit complicated, see #4172.
-            if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
-                if let Some(err) = err.downcast_ref::<DataFusionError>() {
-                    assert!(
-                        matches!(err, DataFusionError::ResourcesExhausted(_)),
-                        "Wrong inner error type: {err}",
-                    );
-                } else {
-                    panic!("Wrong arrow error type: {err}")
-                }
-            } else {
-                panic!("Wrong outer error type: {err}")
-            }
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
         }
 
         Ok(())
diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs
index ee2e976ce..451b0fba4 100644
--- a/datafusion/core/src/physical_plan/repartition.rs
+++ b/datafusion/core/src/physical_plan/repartition.rs
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
 use std::{any::Any, vec};
 
 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use crate::physical_plan::hash_utils::create_hashes;
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
 use tokio::task::JoinHandle;
 
 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
 
 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
 struct RepartitionExecState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
-    channels:
-        HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+    channels: HashMap<
+        usize,
+        (
+            UnboundedSender<MaybeBatch>,
+            UnboundedReceiver<MaybeBatch>,
+            SharedMemoryReservation,
+        ),
+    >,
 
     /// Helper that ensures that that background job is killed once it is no longer needed.
     abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
                 // for this would be to add spill-to-disk capabilities.
                 let (sender, receiver) =
                     mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-                state.channels.insert(partition, (sender, receiver));
+                let reservation = Arc::new(Mutex::new(
+                    MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                        .register(context.memory_pool()),
+                ));
+                state
+                    .channels
+                    .insert(partition, (sender, receiver, reservation));
             }
 
             // launch one async task per *input* partition
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
                 let txs: HashMap<_, _> = state
                     .channels
                     .iter()
-                    .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+                    .map(|(partition, (tx, _rx, reservation))| {
+                        (*partition, (tx.clone(), Arc::clone(reservation)))
+                    })
                     .collect();
 
                 let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
                 // (and pass along any errors, including panic!s)
                 let join_handle = tokio::spawn(Self::wait_for_task(
                     AbortOnDropSingle::new(input_task),
-                    txs,
+                    txs.into_iter()
+                        .map(|(partition, (tx, _reservation))| (partition, tx))
+                        .collect(),
                 ));
                 join_handles.push(join_handle);
             }
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {
 
         // now return stream for the specified *output* partition which will
         // read from the channel
+        let (_tx, rx, reservation) = state
+            .channels
+            .remove(&partition)
+            .expect("partition not used yet");
         Ok(Box::pin(RepartitionStream {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: UnboundedReceiverStream::new(
-                state.channels.remove(&partition).unwrap().1,
-            ),
+            input: UnboundedReceiverStream::new(rx),
             drop_helper: Arc::clone(&state.abort_helper),
+            reservation,
         }))
     }
 
@@ -439,7 +460,7 @@ impl RepartitionExec {
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
         i: usize,
-        mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+        mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
         partitioning: Partitioning,
         r_metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
@@ -467,11 +488,16 @@ impl RepartitionExec {
             };
 
             partitioner.partition(batch, |partition, partitioned| {
+                let size = partitioned.get_array_memory_size();
+
                 let timer = r_metrics.send_time.timer();
                 // if there is still a receiver, send to it
-                if let Some(tx) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                    reservation.lock().try_grow(size)?;
+
                     if tx.send(Some(Ok(partitioned))).is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+                        reservation.lock().shrink(size);
                         txs.remove(&partition);
                     }
                 }
@@ -546,6 +572,9 @@ struct RepartitionStream {
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
     drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
 }
 
 impl Stream for RepartitionStream {
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        match self.input.poll_next_unpin(cx) {
-            Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
-            Poll::Ready(Some(None)) => {
-                self.num_input_partitions_processed += 1;
-                if self.num_input_partitions == self.num_input_partitions_processed {
-                    // all input partitions have finished sending batches
-                    Poll::Ready(None)
-                } else {
-                    // other partitions still have data to send
-                    self.poll_next(cx)
+        loop {
+            match self.input.poll_next_unpin(cx) {
+                Poll::Ready(Some(Some(v))) => {
+                    if let Ok(batch) = &v {
+                        self.reservation
+                            .lock()
+                            .shrink(batch.get_array_memory_size());
+                    }
+
+                    return Poll::Ready(Some(v));
+                }
+                Poll::Ready(Some(None)) => {
+                    self.num_input_partitions_processed += 1;
+
+                    if self.num_input_partitions == self.num_input_partitions_processed {
+                        // all input partitions have finished sending batches
+                        return Poll::Ready(None);
+                    } else {
+                        // other partitions still have data to send
+                        continue;
+                    }
+                }
+                Poll::Ready(None) => {
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => {
+                    return Poll::Pending;
                 }
             }
-            Poll::Ready(None) => Poll::Ready(None),
-            Poll::Pending => Poll::Pending,
         }
     }
 }
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::SessionConfig;
+    use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use crate::from_slice::FromSlice;
     use crate::prelude::SessionContext;
     use crate::test::create_vec_batches;
@@ -1078,4 +1124,41 @@ mod tests {
         assert!(batch0.is_empty() || batch1.is_empty());
         Ok(())
     }
+
+    #[tokio::test]
+    async fn oom() -> Result<()> {
+        // input: one partition from `create_vec_batches(&schema, 50)`, split round-robin into 4 outputs
+        let schema = test_schema();
+        let partition = create_vec_batches(&schema, 50);
+        let input_partitions = vec![partition];
+        let partitioning = Partitioning::RoundRobinBatch(4);
+
+        // set up a context with a tiny memory limit (`with_memory_limit(1, 1.0)`) so buffering any batch fails
+        let session_ctx = SessionContext::with_config_rt(
+            SessionConfig::default(),
+            Arc::new(
+                RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+                    .unwrap(),
+            ),
+        );
+        let task_ctx = session_ctx.task_ctx();
+
+        // plan: MemoryExec -> RepartitionExec
+        let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+        let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+        // each output partition's first item must be an error whose root cause is ResourcesExhausted
+        for i in 0..exec.partitioning.partition_count() {
+            let mut stream = exec.execute(i, task_ctx.clone())?;
+            let err =
+                DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+            let err = err.find_root();
+            assert!(
+                matches!(err, DataFusionError::ResourcesExhausted(_)),
+                "Wrong error type: {err}",
+            );
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs
index 91d66e884..170f55903 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)
 
     let runtime = RuntimeEnv::new(rt_config).unwrap();
 
-    let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+    let ctx = SessionContext::with_config_rt(
+        // do NOT re-partition (RepartitionExec also has a memory budget, which we would likely hit first)
+        SessionConfig::new().with_target_partitions(1),
+        Arc::new(runtime),
+    );
     ctx.register_table("t", Arc::new(table))
         .expect("registering table");