Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/07 09:19:32 UTC
[arrow-datafusion] branch master updated: fix: account for memory in `RepartitionExec` (#4820)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 83c102698 fix: account for memory in `RepartitionExec` (#4820)
83c102698 is described below
commit 83c102698945e0984f8fa53e75b04478e49e5242
Author: Marco Neumann <ma...@crepererum.net>
AuthorDate: Sat Jan 7 10:19:25 2023 +0100
fix: account for memory in `RepartitionExec` (#4820)
* refactor: explicit loop instead of (tail) recursion
* test: simplify
* fix: account for memory in `RepartitionExec`
Fixes #4816.
* fix: sorting memory limit test
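
Editor's note: the core of the fix is the `MemoryConsumer` / `MemoryReservation` API: each output
partition registers a consumer against the task's memory pool, grows its reservation when a batch
is queued, and shrinks it when the batch is consumed. A minimal sketch of that lifecycle against a
bounded pool (assuming `GreedyMemoryPool` as the stock bounded pool; the consumer name is
illustrative):

    use std::sync::Arc;

    use datafusion::error::Result;
    use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

    fn main() -> Result<()> {
        // Bounded pool: growing past 1024 bytes fails with ResourcesExhausted.
        let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));

        // One consumer per output partition, as the patch does per channel.
        let mut reservation = MemoryConsumer::new("RepartitionExec[0]").register(&pool);

        // Grow before buffering a batch; the Err propagates as the query error.
        reservation.try_grow(512)?;
        assert!(reservation.try_grow(4096).is_err());

        // Shrink once the batch is handed to the reading side.
        reservation.shrink(512);
        Ok(())
    }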
---
.../core/src/physical_plan/aggregates/mod.rs | 19 +--
datafusion/core/src/physical_plan/repartition.rs | 127 +++++++++++++++++----
datafusion/core/tests/memory_limit.rs | 6 +-
3 files changed, 116 insertions(+), 36 deletions(-)
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index 07f3563bb..8044f4c15 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -746,7 +746,7 @@ mod tests {
use crate::{assert_batches_sorted_eq, physical_plan::common};
use arrow::array::{Float64Array, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
- use arrow::error::{ArrowError, Result as ArrowResult};
+ use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
@@ -1207,18 +1207,11 @@ mod tests {
let err = common::collect(stream).await.unwrap_err();
// error root cause traversal is a bit complicated, see #4172.
- if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err {
- if let Some(err) = err.downcast_ref::<DataFusionError>() {
- assert!(
- matches!(err, DataFusionError::ResourcesExhausted(_)),
- "Wrong inner error type: {err}",
- );
- } else {
- panic!("Wrong arrow error type: {err}")
- }
- } else {
- panic!("Wrong outer error type: {err}")
- }
+ let err = err.find_root();
+ assert!(
+ matches!(err, DataFusionError::ResourcesExhausted(_)),
+ "Wrong error type: {err}",
+ );
}
Ok(())
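
Editor's note: the hunk above drops the manual downcast chain in favor of
`DataFusionError::find_root`, which walks wrapped errors (including `ArrowError::ExternalError`)
down to the innermost `DataFusionError`. A small sketch of the behavior the test now relies on
(the nesting is hand-built for illustration):

    use arrow::error::ArrowError;
    use datafusion::error::DataFusionError;

    fn main() {
        // Hand-built nesting, mimicking how the error surfaces through a stream.
        let inner = DataFusionError::ResourcesExhausted("oom".to_string());
        let wrapped =
            DataFusionError::ArrowError(ArrowError::ExternalError(Box::new(inner)));

        // find_root unwraps the ArrowError / ExternalError layers again.
        assert!(matches!(
            wrapped.find_root(),
            DataFusionError::ResourcesExhausted(_)
        ));
    }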
diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs
index ee2e976ce..451b0fba4 100644
--- a/datafusion/core/src/physical_plan/repartition.rs
+++ b/datafusion/core/src/physical_plan/repartition.rs
@@ -24,6 +24,7 @@ use std::task::{Context, Poll};
use std::{any::Any, vec};
use crate::error::{DataFusionError, Result};
+use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use crate::physical_plan::hash_utils::create_hashes;
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
@@ -50,14 +51,21 @@ use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
/// Inner state of [`RepartitionExec`].
#[derive(Debug)]
struct RepartitionExecState {
/// Channels for sending batches from input partitions to output partitions.
/// Key is the partition number.
- channels:
- HashMap<usize, (UnboundedSender<MaybeBatch>, UnboundedReceiver<MaybeBatch>)>,
+ channels: HashMap<
+ usize,
+ (
+ UnboundedSender<MaybeBatch>,
+ UnboundedReceiver<MaybeBatch>,
+ SharedMemoryReservation,
+ ),
+ >,
/// Helper that ensures that the background job is killed once it is no longer needed.
abort_helper: Arc<AbortOnDropMany<()>>,
@@ -338,7 +346,13 @@ impl ExecutionPlan for RepartitionExec {
// for this would be to add spill-to-disk capabilities.
let (sender, receiver) =
mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
- state.channels.insert(partition, (sender, receiver));
+ let reservation = Arc::new(Mutex::new(
+ MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+ .register(context.memory_pool()),
+ ));
+ state
+ .channels
+ .insert(partition, (sender, receiver, reservation));
}
// launch one async task per *input* partition
@@ -347,7 +361,9 @@ impl ExecutionPlan for RepartitionExec {
let txs: HashMap<_, _> = state
.channels
.iter()
- .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
+ .map(|(partition, (tx, _rx, reservation))| {
+ (*partition, (tx.clone(), Arc::clone(reservation)))
+ })
.collect();
let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
@@ -366,7 +382,9 @@ impl ExecutionPlan for RepartitionExec {
// (and pass along any errors, including panic!s)
let join_handle = tokio::spawn(Self::wait_for_task(
AbortOnDropSingle::new(input_task),
- txs,
+ txs.into_iter()
+ .map(|(partition, (tx, _reservation))| (partition, tx))
+ .collect(),
));
join_handles.push(join_handle);
}
@@ -381,14 +399,17 @@ impl ExecutionPlan for RepartitionExec {
// now return stream for the specified *output* partition which will
// read from the channel
+ let (_tx, rx, reservation) = state
+ .channels
+ .remove(&partition)
+ .expect("partition not used yet");
Ok(Box::pin(RepartitionStream {
num_input_partitions,
num_input_partitions_processed: 0,
schema: self.input.schema(),
- input: UnboundedReceiverStream::new(
- state.channels.remove(&partition).unwrap().1,
- ),
+ input: UnboundedReceiverStream::new(rx),
drop_helper: Arc::clone(&state.abort_helper),
+ reservation,
}))
}
@@ -439,7 +460,7 @@ impl RepartitionExec {
async fn pull_from_input(
input: Arc<dyn ExecutionPlan>,
i: usize,
- mut txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
+ mut txs: HashMap<usize, (UnboundedSender<MaybeBatch>, SharedMemoryReservation)>,
partitioning: Partitioning,
r_metrics: RepartitionMetrics,
context: Arc<TaskContext>,
@@ -467,11 +488,16 @@ impl RepartitionExec {
};
partitioner.partition(batch, |partition, partitioned| {
+ let size = partitioned.get_array_memory_size();
+
let timer = r_metrics.send_time.timer();
// if there is still a receiver, send to it
- if let Some(tx) = txs.get_mut(&partition) {
+ if let Some((tx, reservation)) = txs.get_mut(&partition) {
+ reservation.lock().try_grow(size)?;
+
if tx.send(Some(Ok(partitioned))).is_err() {
// If the other end has hung up, it was an early shutdown (e.g. LIMIT)
+ reservation.lock().shrink(size);
txs.remove(&partition);
}
}
@@ -546,6 +572,9 @@ struct RepartitionStream {
/// Handle to ensure background tasks are killed when no longer needed.
#[allow(dead_code)]
drop_helper: Arc<AbortOnDropMany<()>>,
+
+ /// Memory reservation.
+ reservation: SharedMemoryReservation,
}
impl Stream for RepartitionStream {
@@ -555,20 +584,35 @@ impl Stream for RepartitionStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
- match self.input.poll_next_unpin(cx) {
- Poll::Ready(Some(Some(v))) => Poll::Ready(Some(v)),
- Poll::Ready(Some(None)) => {
- self.num_input_partitions_processed += 1;
- if self.num_input_partitions == self.num_input_partitions_processed {
- // all input partitions have finished sending batches
- Poll::Ready(None)
- } else {
- // other partitions still have data to send
- self.poll_next(cx)
+ loop {
+ match self.input.poll_next_unpin(cx) {
+ Poll::Ready(Some(Some(v))) => {
+ if let Ok(batch) = &v {
+ self.reservation
+ .lock()
+ .shrink(batch.get_array_memory_size());
+ }
+
+ return Poll::Ready(Some(v));
+ }
+ Poll::Ready(Some(None)) => {
+ self.num_input_partitions_processed += 1;
+
+ if self.num_input_partitions == self.num_input_partitions_processed {
+ // all input partitions have finished sending batches
+ return Poll::Ready(None);
+ } else {
+ // other partitions still have data to send
+ continue;
+ }
+ }
+ Poll::Ready(None) => {
+ return Poll::Ready(None);
+ }
+ Poll::Pending => {
+ return Poll::Pending;
}
}
- Poll::Ready(None) => Poll::Ready(None),
- Poll::Pending => Poll::Pending,
}
}
}
@@ -583,6 +627,8 @@ impl RecordBatchStream for RepartitionStream {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::execution::context::SessionConfig;
+ use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use crate::from_slice::FromSlice;
use crate::prelude::SessionContext;
use crate::test::create_vec_batches;
@@ -1078,4 +1124,41 @@ mod tests {
assert!(batch0.is_empty() || batch1.is_empty());
Ok(())
}
+
+ #[tokio::test]
+ async fn oom() -> Result<()> {
+ // define input partitions
+ let schema = test_schema();
+ let partition = create_vec_batches(&schema, 50);
+ let input_partitions = vec![partition];
+ let partitioning = Partitioning::RoundRobinBatch(4);
+
+ // set up context
+ let session_ctx = SessionContext::with_config_rt(
+ SessionConfig::default(),
+ Arc::new(
+ RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0))
+ .unwrap(),
+ ),
+ );
+ let task_ctx = session_ctx.task_ctx();
+
+ // create physical plan
+ let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?;
+ let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?;
+
+ // pull partitions
+ for i in 0..exec.partitioning.partition_count() {
+ let mut stream = exec.execute(i, task_ctx.clone())?;
+ let err =
+ DataFusionError::ArrowError(stream.next().await.unwrap().unwrap_err());
+ let err = err.find_root();
+ assert!(
+ matches!(err, DataFusionError::ResourcesExhausted(_)),
+ "Wrong error type: {err}",
+ );
+ }
+
+ Ok(())
+ }
}
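
Editor's note: two patterns in the hunks above are worth calling out. First, the send side pairs
every `try_grow` with a `shrink`: the sender grows the reservation by
`get_array_memory_size()` before queueing a batch, and the receiver shrinks it by the same amount
once the batch is polled out, so the reservation tracks only what is buffered in the channel.
Second, `RepartitionStream::poll_next` now uses an explicit `loop` rather than recursively calling
`self.poll_next(cx)` when it sees an end-of-partition marker. A standalone sketch of that same
loop pattern (the `SkipNones` adapter is hypothetical, not part of DataFusion):

    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures::stream::{self, Stream, StreamExt};

    /// Hypothetical adapter: yields the `Some(v)` items of an inner stream and
    /// skips `None` markers, looping instead of recursing into `poll_next`.
    struct SkipNones<S>(S);

    impl<S, T> Stream for SkipNones<S>
    where
        S: Stream<Item = Option<T>> + Unpin,
    {
        type Item = T;

        fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
            loop {
                match self.0.poll_next_unpin(cx) {
                    // marker seen: poll the inner stream again in this same call
                    Poll::Ready(Some(None)) => continue,
                    Poll::Ready(Some(Some(v))) => return Poll::Ready(Some(v)),
                    Poll::Ready(None) => return Poll::Ready(None),
                    Poll::Pending => return Poll::Pending,
                }
            }
        }
    }

    fn main() {
        let inner = stream::iter(vec![Some(1), None, Some(2), None]);
        let out: Vec<i32> = futures::executor::block_on(SkipNones(inner).collect());
        assert_eq!(out, vec![1, 2]);
    }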
diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs
index 91d66e884..170f55903 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -95,7 +95,11 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize)
let runtime = RuntimeEnv::new(rt_config).unwrap();
- let ctx = SessionContext::with_config_rt(SessionConfig::new(), Arc::new(runtime));
+ let ctx = SessionContext::with_config_rt(
+ // do NOT re-partition (since RepartitionExec also has a memory budget, which we'll likely hit first)
+ SessionConfig::new().with_target_partitions(1),
+ Arc::new(runtime),
+ );
ctx.register_table("t", Arc::new(table))
.expect("registering table");