Posted to commits@arrow.apache.org by al...@apache.org on 2023/06/06 19:38:28 UTC
[arrow-datafusion] branch main updated: Fix panic propagation in `CoalescePartitions`, consolidates panic propagation into `RecordBatchReceiverStream` (#6507)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 39ee59aeb5 Fix panic propagation in `CoalescePartitions`, consolidates panic propagation into `RecordBatchReceiverStream` (#6507)
39ee59aeb5 is described below
commit 39ee59aeb525739879fde8d3213f56be49fb12da
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Tue Jun 6 15:38:21 2023 -0400
Fix panic propagation in `CoalescePartitions`, consolidates panic propagation into `RecordBatchReceiverStream` (#6507)
* Propagate panics
Another try at fixing #3104.
RepartitionExec might need a similar fix.
* avoid allocation by pinning on the stack instead
* Consolidate panic propagation into RecordBatchReceiverStream
* Update docs / cleanup
* Apply suggestions from code review
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
* rename to be consistent and not deal with English peculiarities
* Add a test and comments
* write test for drop cancel
* Add test for not driving to completion
* Do not drive all streams to error
* terminate early on panic
* tweak comments
* tweak comments
* use futures::stream
---------
Co-authored-by: Nicolae Vartolomei <nv...@nvartolomei.com>
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
datafusion/core/src/physical_plan/analyze.rs | 32 +-
.../core/src/physical_plan/coalesce_partitions.rs | 75 ++--
datafusion/core/src/physical_plan/common.rs | 56 +--
datafusion/core/src/physical_plan/sorts/sort.rs | 16 +-
.../physical_plan/sorts/sort_preserving_merge.rs | 13 +-
datafusion/core/src/physical_plan/stream.rs | 382 +++++++++++++++++++--
datafusion/core/src/physical_plan/union.rs | 37 +-
datafusion/core/src/test/exec.rs | 230 +++++++++++--
8 files changed, 610 insertions(+), 231 deletions(-)
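The diffs below replace per-operator mpsc plumbing with a builder that owns a tokio JoinSet and re-raises task panics on the consumer side. As a rough, self-contained sketch of the select-based propagation pattern they introduce (plain tokio/futures/tokio-stream, no DataFusion types; tokio "full" features assumed):

    use futures::StreamExt;
    use tokio::task::JoinSet;
    use tokio_stream::wrappers::ReceiverStream;

    #[tokio::main]
    async fn main() {
        let (tx, rx) = tokio::sync::mpsc::channel::<i32>(2);
        let mut join_set = JoinSet::new();

        // Producer task that fails partway through, like PanicStream below.
        join_set.spawn(async move {
            tx.send(1).await.ok();
            panic!("producer failed");
        });

        // Future that watches the JoinSet and re-raises any task panic.
        let check = async move {
            while let Some(result) = join_set.join_next().await {
                if let Err(e) = result {
                    if e.is_panic() {
                        std::panic::resume_unwind(e.into_panic());
                    }
                }
            }
            None::<i32>
        };
        let check_stream = futures::stream::once(check)
            .filter_map(|item| async move { item });

        // Whichever stream is ready first wins: data, or the panic.
        let mut merged =
            futures::stream::select(ReceiverStream::new(rx), check_stream).boxed();
        while let Some(item) = merged.next().await {
            println!("got {item}"); // prints "got 1", then panics
        }
    }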
diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs
index 39d715761f..9be68337b2 100644
--- a/datafusion/core/src/physical_plan/analyze.rs
+++ b/datafusion/core/src/physical_plan/analyze.rs
@@ -29,10 +29,9 @@ use crate::{
};
use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch};
use futures::StreamExt;
-use tokio::task::JoinSet;
use super::expressions::PhysicalSortExpr;
-use super::stream::RecordBatchStreamAdapter;
+use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
use super::{Distribution, SendableRecordBatchStream};
use datafusion_execution::TaskContext;
@@ -121,23 +120,15 @@ impl ExecutionPlan for AnalyzeExec {
// Gather futures that will run each input partition in
// parallel (on a separate tokio task) using a JoinSet to
// cancel outstanding futures on drop
- let mut set = JoinSet::new();
let num_input_partitions = self.input.output_partitioning().partition_count();
+ let mut builder =
+ RecordBatchReceiverStream::builder(self.schema(), num_input_partitions);
for input_partition in 0..num_input_partitions {
- let input_stream = self.input.execute(input_partition, context.clone());
-
- set.spawn(async move {
- let mut total_rows = 0;
- let mut input_stream = input_stream?;
- while let Some(batch) = input_stream.next().await {
- let batch = batch?;
- total_rows += batch.num_rows();
- }
- Ok(total_rows) as Result<usize>
- });
+ builder.run_input(self.input.clone(), input_partition, context.clone());
}
+ // Create future that computes the final output
let start = Instant::now();
let captured_input = self.input.clone();
let captured_schema = self.schema.clone();
@@ -146,18 +137,11 @@ impl ExecutionPlan for AnalyzeExec {
// future that gathers the results from all the tasks in the
// JoinSet that computes the overall row count and final
// record batch
+ let mut input_stream = builder.build();
let output = async move {
let mut total_rows = 0;
- while let Some(res) = set.join_next().await {
- // translate join errors (aka task panic's) into ExecutionErrors
- match res {
- Ok(row_count) => total_rows += row_count?,
- Err(e) => {
- return Err(DataFusionError::Execution(format!(
- "Join error in AnalyzeExec: {e}"
- )))
- }
- }
+ while let Some(batch) = input_stream.next().await.transpose()? {
+ total_rows += batch.num_rows();
}
let duration = Instant::now() - start;
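The rewritten output future above leans on `Option::transpose` so `?` can apply inside the `while let`; a tiny standalone illustration of the idiom (plain Rust, no DataFusion types):

    fn sum(items: Vec<Result<u64, String>>) -> Result<u64, String> {
        let mut iter = items.into_iter();
        let mut total = 0;
        // next() yields Option<Result<u64, _>>; transpose() flips it to
        // Result<Option<u64>, _> so ? can bail out on the first error.
        while let Some(v) = iter.next().transpose()? {
            total += v;
        }
        Ok(total)
    }

    fn main() {
        assert_eq!(sum(vec![Ok(1), Ok(2)]), Ok(3));
        assert_eq!(sum(vec![Ok(1), Err("boom".into())]), Err("boom".into()));
    }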
diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 11d7021ca9..66700cd9e7 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -20,25 +20,19 @@
use std::any::Any;
use std::sync::Arc;
-use std::task::Poll;
-
-use futures::Stream;
-use tokio::sync::mpsc;
use arrow::datatypes::SchemaRef;
-use arrow::record_batch::RecordBatch;
-use super::common::AbortOnDropMany;
use super::expressions::PhysicalSortExpr;
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
-use super::{RecordBatchStream, Statistics};
+use super::stream::{ObservedStream, RecordBatchReceiverStream};
+use super::Statistics;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
};
use super::SendableRecordBatchStream;
-use crate::physical_plan::common::spawn_execution;
use datafusion_execution::TaskContext;
/// Merge execution plan executes partitions in parallel and combines them into a single
@@ -137,27 +131,17 @@ impl ExecutionPlan for CoalescePartitionsExec {
// use a stream that allows each sender to put in at
// least one result in an attempt to maximize
// parallelism.
- let (sender, receiver) =
- mpsc::channel::<Result<RecordBatch>>(input_partitions);
+ let mut builder =
+ RecordBatchReceiverStream::builder(self.schema(), input_partitions);
// spawn independent tasks whose resulting streams (of batches)
// are sent to the channel for consumption.
- let mut join_handles = Vec::with_capacity(input_partitions);
for part_i in 0..input_partitions {
- join_handles.push(spawn_execution(
- self.input.clone(),
- sender.clone(),
- part_i,
- context.clone(),
- ));
+ builder.run_input(self.input.clone(), part_i, context.clone());
}
- Ok(Box::pin(MergeStream {
- input: receiver,
- schema: self.schema(),
- baseline_metrics,
- drop_helper: AbortOnDropMany(join_handles),
- }))
+ let stream = builder.build();
+ Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)))
}
}
}
@@ -183,32 +167,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
}
}
-struct MergeStream {
- schema: SchemaRef,
- input: mpsc::Receiver<Result<RecordBatch>>,
- baseline_metrics: BaselineMetrics,
- #[allow(unused)]
- drop_helper: AbortOnDropMany<()>,
-}
-
-impl Stream for MergeStream {
- type Item = Result<RecordBatch>;
-
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- let poll = self.input.poll_recv(cx);
- self.baseline_metrics.record_poll(poll)
- }
-}
-
-impl RecordBatchStream for MergeStream {
- fn schema(&self) -> SchemaRef {
- self.schema.clone()
- }
-}
-
#[cfg(test)]
mod tests {
@@ -218,7 +176,9 @@ mod tests {
use super::*;
use crate::physical_plan::{collect, common};
use crate::prelude::SessionContext;
- use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
+ use crate::test::exec::{
+ assert_strong_count_converges_to_zero, BlockingExec, PanicExec,
+ };
use crate::test::{self, assert_is_pending};
#[tokio::test]
@@ -270,4 +230,19 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ #[should_panic(expected = "PanickingStream did panic")]
+ async fn test_panic() {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
+
+ let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
+ let coalesce_partitions_exec =
+ Arc::new(CoalescePartitionsExec::new(panicking_exec));
+
+ collect(coalesce_partitions_exec, task_ctx).await.unwrap();
+ }
}
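`ObservedStream` (defined later in this diff, in stream.rs) is a plain delegating wrapper around another stream; its shape, with a simple poll counter standing in for `BaselineMetrics`, looks roughly like:

    use std::pin::Pin;
    use std::task::{Context, Poll};
    use futures::{Stream, StreamExt};

    struct CountingStream<S> {
        inner: S,
        polls: usize, // stand-in for BaselineMetrics::record_poll
    }

    impl<S: Stream + Unpin> Stream for CountingStream<S> {
        type Item = S::Item;

        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            self.polls += 1; // observe every poll, then delegate
            self.inner.poll_next_unpin(cx)
        }
    }

    fn main() {
        let mut s = CountingStream {
            inner: futures::stream::iter(vec![1, 2, 3]),
            polls: 0,
        };
        futures::executor::block_on(async {
            while let Some(v) = s.next().await {
                let _ = v;
            }
        });
        println!("{} polls", s.polls); // 4: three items plus the end
    }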
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index 98239557cb..2f296ce462 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -21,15 +21,13 @@ use super::SendableRecordBatchStream;
use crate::error::{DataFusionError, Result};
use crate::execution::memory_pool::MemoryReservation;
use crate::physical_plan::stream::RecordBatchReceiverStream;
-use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
+use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
use arrow::datatypes::Schema;
use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
use arrow::record_batch::RecordBatch;
-use datafusion_execution::TaskContext;
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use futures::{Future, StreamExt, TryStreamExt};
-use log::debug;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::fs;
@@ -37,7 +35,6 @@ use std::fs::{metadata, File};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::task::{Context, Poll};
-use tokio::sync::mpsc;
use tokio::task::JoinHandle;
/// [`MemoryReservation`] used across query execution streams
@@ -96,42 +93,6 @@ fn build_file_list_recurse(
Ok(())
}
-/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
-pub(crate) fn spawn_execution(
- input: Arc<dyn ExecutionPlan>,
- output: mpsc::Sender<Result<RecordBatch>>,
- partition: usize,
- context: Arc<TaskContext>,
-) -> JoinHandle<()> {
- tokio::spawn(async move {
- let mut stream = match input.execute(partition, context) {
- Err(e) => {
- // If send fails, plan being torn down,
- // there is no place to send the error.
- output.send(Err(e)).await.ok();
- debug!(
- "Stopping execution: error executing input: {}",
- displayable(input.as_ref()).one_line()
- );
- return;
- }
- Ok(stream) => stream,
- };
-
- while let Some(item) = stream.next().await {
- // If send fails, plan being torn down,
- // there is no place to send the error.
- if output.send(item).await.is_err() {
- debug!(
- "Stopping execution: output is gone, plan cancelling: {}",
- displayable(input.as_ref()).one_line()
- );
- return;
- }
- }
- })
-}
-
/// If running in a tokio context spawns the execution of `stream` to a separate task
/// allowing it to execute in parallel with an intermediate buffer of size `buffer`
pub(crate) fn spawn_buffered(
@@ -139,14 +100,15 @@ pub(crate) fn spawn_buffered(
buffer: usize,
) -> SendableRecordBatchStream {
// Use tokio only if running from a tokio context (#2201)
- let handle = match tokio::runtime::Handle::try_current() {
- Ok(handle) => handle,
- Err(_) => return input,
+ if tokio::runtime::Handle::try_current().is_err() {
+ return input;
};
- let schema = input.schema();
- let (sender, receiver) = mpsc::channel(buffer);
- let join = handle.spawn(async move {
+ let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer);
+
+ let sender = builder.tx();
+
+ builder.spawn(async move {
while let Some(item) = input.next().await {
if sender.send(item).await.is_err() {
return;
@@ -154,7 +116,7 @@ pub(crate) fn spawn_buffered(
}
});
- RecordBatchReceiverStream::create(&schema, receiver, join)
+ builder.build()
}
/// Computes the statistics for an in-memory RecordBatch
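`spawn_buffered` now returns the input unchanged when no tokio runtime is present; the detection is just `Handle::try_current`, e.g. (tokio with runtime features assumed):

    fn inside_tokio_runtime() -> bool {
        tokio::runtime::Handle::try_current().is_ok()
    }

    fn main() {
        // Plain main: no runtime has been entered yet.
        assert!(!inside_tokio_runtime());

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            // Inside block_on a handle is available.
            assert!(inside_tokio_runtime());
        });
        println!("runtime detection works");
    }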
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index 3e3a79495b..53177310cc 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -52,7 +52,7 @@ use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tempfile::NamedTempFile;
-use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::sync::mpsc::Sender;
use tokio::task;
struct ExternalSorterMetrics {
@@ -373,18 +373,16 @@ fn read_spill_as_stream(
path: NamedTempFile,
schema: SchemaRef,
) -> Result<SendableRecordBatchStream> {
- let (sender, receiver): (Sender<Result<RecordBatch>>, Receiver<Result<RecordBatch>>) =
- tokio::sync::mpsc::channel(2);
- let join_handle = task::spawn_blocking(move || {
+ let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+ let sender = builder.tx();
+
+ builder.spawn_blocking(move || {
if let Err(e) = read_spill(sender, path.path()) {
error!("Failure while reading spill file: {:?}. Error: {}", path, e);
}
});
- Ok(RecordBatchReceiverStream::create(
- &schema,
- receiver,
- join_handle,
- ))
+
+ Ok(builder.build())
}
fn write_sorted(
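`read_spill_as_stream` now pairs `spawn_blocking` with the builder's channel. The underlying pattern, a blocking producer feeding an async consumer via `blocking_send`, in isolation (tokio "full" features assumed):

    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<u32>(2);

        // Blocking work (e.g. reading a spill file) runs on the blocking
        // pool; blocking_send is allowed because this closure is not async.
        let reader = tokio::task::spawn_blocking(move || {
            for v in 0..3 {
                if tx.blocking_send(v).is_err() {
                    return; // receiver dropped: stop reading early
                }
            }
        });

        while let Some(v) = rx.recv().await {
            println!("read {v}");
        }
        reader.await.unwrap();
    }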
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 95cc23a20c..eb2725ade1 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -792,9 +792,12 @@ mod tests {
let mut streams = Vec::with_capacity(partition_count);
for partition in 0..partition_count {
- let (sender, receiver) = tokio::sync::mpsc::channel(1);
+ let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
+
+ let sender = builder.tx();
+
let mut stream = batches.execute(partition, task_ctx.clone()).unwrap();
- let join_handle = tokio::spawn(async move {
+ builder.spawn(async move {
while let Some(batch) = stream.next().await {
sender.send(batch).await.unwrap();
// This causes the MergeStream to wait for more input
@@ -802,11 +805,7 @@ mod tests {
}
});
- streams.push(RecordBatchReceiverStream::create(
- &schema,
- receiver,
- join_handle,
- ));
+ streams.push(builder.build());
}
let metrics = ExecutionPlanMetricsSet::new();
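The test above relies on the channel's capacity of 1 for backpressure ("This causes the MergeStream to wait for more input"); the bounded-channel behavior in isolation (tokio "full" features assumed):

    use std::time::Duration;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel::<&str>(1);

        let sender = tokio::spawn(async move {
            tx.send("first").await.unwrap(); // fills the one-slot buffer
            tx.send("second").await.unwrap(); // waits until a recv happens
            println!("both sent");
        });

        // Give the sender a chance to run; it is now parked on send().
        tokio::time::sleep(Duration::from_millis(10)).await;

        println!("{}", rx.recv().await.unwrap()); // unblocks the sender
        println!("{}", rx.recv().await.unwrap());
        sender.await.unwrap();
    }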
diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs
index 2190022bc5..75a0f45e1e 100644
--- a/datafusion/core/src/physical_plan/stream.rs
+++ b/datafusion/core/src/physical_plan/stream.rs
@@ -17,43 +17,205 @@
//! Stream wrappers for physical operators
+use std::sync::Arc;
+
use crate::error::Result;
+use crate::physical_plan::displayable;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use futures::{Stream, StreamExt};
+use datafusion_common::DataFusionError;
+use datafusion_execution::TaskContext;
+use futures::stream::BoxStream;
+use futures::{Future, Stream, StreamExt};
+use log::debug;
use pin_project_lite::pin_project;
-use tokio::task::JoinHandle;
+use tokio::sync::mpsc::{Receiver, Sender};
+use tokio::task::JoinSet;
use tokio_stream::wrappers::ReceiverStream;
-use super::common::AbortOnDropSingle;
-use super::{RecordBatchStream, SendableRecordBatchStream};
+use super::metrics::BaselineMetrics;
+use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
-/// Adapter for a tokio [`ReceiverStream`] that implements the
-/// [`SendableRecordBatchStream`]
-/// interface
-pub struct RecordBatchReceiverStream {
+/// Builder for [`RecordBatchReceiverStream`] that propagates errors
+/// and panics correctly.
+///
+/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks
+/// that produce `RecordBatch`es and send them to a single
+/// `Receiver`, which can improve parallelism.
+///
+/// This also handles propagating panics and canceling the tasks.
+pub struct RecordBatchReceiverStreamBuilder {
+ tx: Sender<Result<RecordBatch>>,
+ rx: Receiver<Result<RecordBatch>>,
schema: SchemaRef,
+ join_set: JoinSet<()>,
+}
+
+impl RecordBatchReceiverStreamBuilder {
+ /// Create a new channel with the specified buffer size
+ pub fn new(schema: SchemaRef, capacity: usize) -> Self {
+ let (tx, rx) = tokio::sync::mpsc::channel(capacity);
- inner: ReceiverStream<Result<RecordBatch>>,
+ Self {
+ tx,
+ rx,
+ schema,
+ join_set: JoinSet::new(),
+ }
+ }
+
+ /// Get a handle for sending [`RecordBatch`]es to the output
+ pub fn tx(&self) -> Sender<Result<RecordBatch>> {
+ self.tx.clone()
+ }
+
+ /// Spawn a task that will be aborted if this builder (or the stream
+ /// built from it) is dropped
+ ///
+ /// This is often used to spawn tasks that write to the sender
+ /// retrieved from `Self::tx`
+ pub fn spawn<F>(&mut self, task: F)
+ where
+ F: Future<Output = ()>,
+ F: Send + 'static,
+ {
+ self.join_set.spawn(task);
+ }
+
+ /// Spawn a blocking task that will be aborted if this builder (or the stream
+ /// built from it) is dropped
+ ///
+ /// This is often used to spawn tasks that write to the sender
+ /// retrieved from `Self::tx`
+ pub fn spawn_blocking<F>(&mut self, f: F)
+ where
+ F: FnOnce(),
+ F: Send + 'static,
+ {
+ self.join_set.spawn_blocking(f);
+ }
+
+ /// Runs the specified `partition` of the `input` ExecutionPlan on the
+ /// tokio threadpool and writes its outputs to this stream
+ ///
+ /// If the input partition produces an error, the error will be
+ /// sent to the output stream and no further results are sent.
+ pub(crate) fn run_input(
+ &mut self,
+ input: Arc<dyn ExecutionPlan>,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) {
+ let output = self.tx();
+
+ self.spawn(async move {
+ let mut stream = match input.execute(partition, context) {
+ Err(e) => {
+ // If send fails, the plan is being torn down and there
+ // is no place to send the error and no reason to continue.
+ output.send(Err(e)).await.ok();
+ debug!(
+ "Stopping execution: error executing input: {}",
+ displayable(input.as_ref()).one_line()
+ );
+ return;
+ }
+ Ok(stream) => stream,
+ };
+
+ // Transfer batches from inner stream to the output tx
+ // immediately.
+ while let Some(item) = stream.next().await {
+ let is_err = item.is_err();
+
+ // If send fails, the plan is being torn down and there is no
+ // place to send the error and no reason to continue.
+ if output.send(item).await.is_err() {
+ debug!(
+ "Stopping execution: output is gone, plan cancelling: {}",
+ displayable(input.as_ref()).one_line()
+ );
+ return;
+ }
+
+ // Stop after the first error is encountered (don't
+ // drive all streams to completion)
+ if is_err {
+ debug!(
+ "Stopping execution: plan returned error: {}",
+ displayable(input.as_ref()).one_line()
+ );
+ return;
+ }
+ }
+ });
+ }
+
+ /// Create a stream of all `RecordBatch`es written to `tx`
+ pub fn build(self) -> SendableRecordBatchStream {
+ let Self {
+ tx,
+ rx,
+ schema,
+ mut join_set,
+ } = self;
- #[allow(dead_code)]
- drop_helper: AbortOnDropSingle<()>,
+ // Drop the builder's own sender so the channel closes once all spawned tasks finish
+ drop(tx);
+
+ // Future that checks the results of the JoinSet and propagates any panic seen
+ let check = async move {
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok(()) => continue, // nothing to report
+ // This means a tokio task error, likely a panic
+ Err(e) => {
+ if e.is_panic() {
+ // resume on the main thread
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ // This should only occur if the task is
+ // cancelled, which would only occur if
+ // the JoinSet were aborted, which in turn
+ // would imply that the receiver has been
+ // dropped and this code is not running
+ return Some(Err(DataFusionError::Internal(format!(
+ "Non Panic Task error: {e}"
+ ))));
+ }
+ }
+ }
+ }
+ None
+ };
+
+ let check_stream = futures::stream::once(check)
+ // flatten the Option: yield only the error, if any
+ .filter_map(|item| async move { item });
+
+ // Merge the streams together so whichever is ready first
+ // produces the batch
+ let inner =
+ futures::stream::select(ReceiverStream::new(rx), check_stream).boxed();
+
+ Box::pin(RecordBatchReceiverStream { schema, inner })
+ }
+}
+
+/// Adapter for a tokio [`ReceiverStream`] that implements the
+/// [`SendableRecordBatchStream`] interface and propagates panics and
+/// errors. Use [`Self::builder`] to construct one.
+pub struct RecordBatchReceiverStream {
+ schema: SchemaRef,
+ inner: BoxStream<'static, Result<RecordBatch>>,
}
impl RecordBatchReceiverStream {
- /// Construct a new [`RecordBatchReceiverStream`] which will send
- /// batches of the specified schema from `inner`
- pub fn create(
- schema: &SchemaRef,
- rx: tokio::sync::mpsc::Receiver<Result<RecordBatch>>,
- join_handle: JoinHandle<()>,
- ) -> SendableRecordBatchStream {
- let schema = schema.clone();
- let inner = ReceiverStream::new(rx);
- Box::pin(Self {
- schema,
- inner,
- drop_helper: AbortOnDropSingle::new(join_handle),
- })
+ /// Create a builder with an internal buffer of `capacity` batches.
+ pub fn builder(
+ schema: SchemaRef,
+ capacity: usize,
+ ) -> RecordBatchReceiverStreamBuilder {
+ RecordBatchReceiverStreamBuilder::new(schema, capacity)
}
}
@@ -126,3 +288,173 @@ where
self.schema.clone()
}
}
+
+/// Stream wrapper that records `BaselineMetrics` for a particular
+/// [`SendableRecordBatchStream`] (likely a partition)
+pub(crate) struct ObservedStream {
+ inner: SendableRecordBatchStream,
+ baseline_metrics: BaselineMetrics,
+}
+
+impl ObservedStream {
+ pub fn new(
+ inner: SendableRecordBatchStream,
+ baseline_metrics: BaselineMetrics,
+ ) -> Self {
+ Self {
+ inner,
+ baseline_metrics,
+ }
+ }
+}
+
+impl RecordBatchStream for ObservedStream {
+ fn schema(&self) -> arrow::datatypes::SchemaRef {
+ self.inner.schema()
+ }
+}
+
+impl futures::Stream for ObservedStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ let poll = self.inner.poll_next_unpin(cx);
+ self.baseline_metrics.record_poll(poll)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use arrow_schema::{DataType, Field, Schema};
+
+ use crate::{
+ execution::context::SessionContext,
+ test::exec::{
+ assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec,
+ },
+ };
+
+ fn schema() -> SchemaRef {
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]))
+ }
+
+ #[tokio::test]
+ #[should_panic(expected = "PanickingStream did panic")]
+ async fn record_batch_receiver_stream_propagates_panics() {
+ let schema = schema();
+
+ let num_partitions = 10;
+ let input = PanicExec::new(schema.clone(), num_partitions);
+ consume(input, 10).await
+ }
+
+ #[tokio::test]
+ #[should_panic(expected = "PanickingStream did panic: 1")]
+ async fn record_batch_receiver_stream_propagates_panics_early_shutdown() {
+ let schema = schema();
+
+ // make 2 partitions, second partition panics before the first
+ let num_partitions = 2;
+ let input = PanicExec::new(schema.clone(), num_partitions)
+ .with_partition_panic(0, 10)
+ .with_partition_panic(1, 3); // partition 1 should panic first (after 3 batches)
+
+ // ensure that the panic results in an early shutdown (that
+ // everything stops after the first panic).
+
+ // The stream interleaves batches from the two partitions: (0,1,0,1,0,panic),
+ // so no more than 5 batches should appear before the panic
+ let max_batches = 5;
+ consume(input, max_batches).await
+ }
+
+ #[tokio::test]
+ async fn record_batch_receiver_stream_drop_cancel() {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let schema = schema();
+
+ // Make an input that never proceeds
+ let input = BlockingExec::new(schema.clone(), 1);
+ let refs = input.refs();
+
+ // Configure a RecordBatchReceiverStream to consume the input
+ let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+ builder.run_input(Arc::new(input), 0, task_ctx.clone());
+ let stream = builder.build();
+
+ // input should still be present
+ assert!(std::sync::Weak::strong_count(&refs) > 0);
+
+ // drop the stream, ensure the refs go to zero
+ drop(stream);
+ assert_strong_count_converges_to_zero(refs).await;
+ }
+
+ #[tokio::test]
+ /// Ensure that if an error is received in one stream, the
+ /// `RecordBatchReceiverStream` stops early and does not drive
+ /// other streams to completion.
+ async fn record_batch_receiver_stream_error_does_not_drive_completion() {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+ let schema = schema();
+
+ // make an input that will error twice
+ let error_stream = MockExec::new(
+ vec![
+ Err(DataFusionError::Execution("Test1".to_string())),
+ Err(DataFusionError::Execution("Test2".to_string())),
+ ],
+ schema.clone(),
+ )
+ .with_use_task(false);
+
+ let mut builder = RecordBatchReceiverStream::builder(schema, 2);
+ builder.run_input(Arc::new(error_stream), 0, task_ctx.clone());
+ let mut stream = builder.build();
+
+ // get the first result, which should be an error
+ let first_batch = stream.next().await.unwrap();
+ let first_err = first_batch.unwrap_err();
+ assert_eq!(first_err.to_string(), "Execution error: Test1");
+
+ // There should be no more batches produced (should not get the second error)
+ assert!(stream.next().await.is_none());
+ }
+
+ /// Consumes all the input's partitions into a
+ /// RecordBatchReceiverStream and runs it to completion
+ ///
+ /// Panics if more than `max_batches` batches are seen.
+ async fn consume(input: PanicExec, max_batches: usize) {
+ let session_ctx = SessionContext::new();
+ let task_ctx = session_ctx.task_ctx();
+
+ let input = Arc::new(input);
+ let num_partitions = input.output_partitioning().partition_count();
+
+ // Configure a RecordBatchReceiverStream to consume all the input partitions
+ let mut builder =
+ RecordBatchReceiverStream::builder(input.schema(), num_partitions);
+ for partition in 0..num_partitions {
+ builder.run_input(input.clone(), partition, task_ctx.clone());
+ }
+ let mut stream = builder.build();
+
+ // drain the stream until it is complete, panicking on error
+ let mut num_batches = 0;
+ while let Some(next) = stream.next().await {
+ next.unwrap();
+ num_batches += 1;
+ assert!(
+ num_batches < max_batches,
+ "Got the limit of {num_batches} batches before seeing panic"
+ );
+ }
+ }
+}
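The `#[should_panic]` tests above work because `build()` re-raises task panics on the consumer side; a minimal standalone version of that propagation, using `tokio::spawn` directly instead of the builder (tokio "macros" and "rt" features assumed):

    #[tokio::test]
    #[should_panic(expected = "boom")]
    async fn task_panic_reaches_test_thread() {
        let handle = tokio::spawn(async { panic!("boom") });
        if let Err(e) = handle.await {
            if e.is_panic() {
                // Same re-raise the builder's check future performs.
                std::panic::resume_unwind(e.into_panic());
            }
        }
    }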
diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs
index 5cf25fbe02..f2b936cf53 100644
--- a/datafusion/core/src/physical_plan/union.rs
+++ b/datafusion/core/src/physical_plan/union.rs
@@ -30,7 +30,7 @@ use arrow::{
record_batch::RecordBatch,
};
use datafusion_common::{DFSchemaRef, DataFusionError};
-use futures::{Stream, StreamExt};
+use futures::Stream;
use itertools::Itertools;
use log::{debug, trace, warn};
@@ -41,6 +41,7 @@ use super::{
SendableRecordBatchStream, Statistics,
};
use crate::physical_plan::common::get_meet_of_orderings;
+use crate::physical_plan::stream::ObservedStream;
use crate::{
error::Result,
physical_plan::{expressions, metrics::BaselineMetrics},
@@ -560,40 +561,6 @@ impl Stream for CombinedRecordBatchStream {
}
}
-/// Stream wrapper that records `BaselineMetrics` for a particular
-/// partition
-struct ObservedStream {
- inner: SendableRecordBatchStream,
- baseline_metrics: BaselineMetrics,
-}
-
-impl ObservedStream {
- fn new(inner: SendableRecordBatchStream, baseline_metrics: BaselineMetrics) -> Self {
- Self {
- inner,
- baseline_metrics,
- }
- }
-}
-
-impl RecordBatchStream for ObservedStream {
- fn schema(&self) -> arrow::datatypes::SchemaRef {
- self.inner.schema()
- }
-}
-
-impl futures::Stream for ObservedStream {
- type Item = Result<RecordBatch>;
-
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Self::Item>> {
- let poll = self.inner.poll_next_unpin(cx);
- self.baseline_metrics.record_poll(poll)
- }
-}
-
fn col_stats_union(
mut left: ColumnStatistics,
right: ColumnStatistics,
diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs
index bce7d08a5c..41a0a1b4d0 100644
--- a/datafusion/core/src/test/exec.rs
+++ b/datafusion/core/src/test/exec.rs
@@ -31,7 +31,6 @@ use arrow::{
};
use futures::Stream;
-use crate::execution::context::TaskContext;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::{
common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -41,6 +40,9 @@ use crate::{
error::{DataFusionError, Result},
physical_plan::stream::RecordBatchReceiverStream,
};
+use crate::{
+ execution::context::TaskContext, physical_plan::stream::RecordBatchStreamAdapter,
+};
/// Index into the data that has been returned so far
#[derive(Debug, Default, Clone)]
@@ -114,22 +116,40 @@ impl RecordBatchStream for TestStream {
}
}
-/// A Mock ExecutionPlan that can be used for writing tests of other ExecutionPlans
-///
+/// A Mock ExecutionPlan that can be used for writing tests of other
+/// ExecutionPlans
#[derive(Debug)]
pub struct MockExec {
/// the results to send back
data: Vec<Result<RecordBatch>>,
schema: SchemaRef,
+ /// if true (the default), sends data using a separate task to ensure the
+ /// batches are not available without this stream yielding first
+ use_task: bool,
}
impl MockExec {
- /// Create a new exec with a single partition that returns the
- /// record batches in this Exec. Note the batches are not produced
- /// immediately (the caller has to actually yield and another task
- /// must run) to ensure any poll loops are correct.
+ /// Create a new `MockExec` with a single partition that returns
+ /// the specified `Result`s.
+ ///
+ /// By default, the batches are not produced immediately (the
+ /// caller has to actually yield and another task must run) to
+ /// ensure any poll loops are correct. This behavior can be
+ /// changed with `with_use_task`
pub fn new(data: Vec<Result<RecordBatch>>, schema: SchemaRef) -> Self {
- Self { data, schema }
+ Self {
+ data,
+ schema,
+ use_task: true,
+ }
+ }
+
+ /// If `use_task` is true (the default) then the batches are sent
+ /// back using a separate task to ensure the underlying stream is
+ /// not immediately ready
+ pub fn with_use_task(mut self, use_task: bool) -> Self {
+ self.use_task = use_task;
+ self
}
}
@@ -179,26 +199,30 @@ impl ExecutionPlan for MockExec {
})
.collect();
- let (tx, rx) = tokio::sync::mpsc::channel(2);
-
- // task simply sends data in order but in a separate
- // thread (to ensure the batches are not available without the
- // DelayedStream yielding).
- let join_handle = tokio::task::spawn(async move {
- for batch in data {
- println!("Sending batch via delayed stream");
- if let Err(e) = tx.send(batch).await {
- println!("ERROR batch via delayed stream: {e}");
+ if self.use_task {
+ let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2);
+ // send data in order but in a separate task (to ensure
+ // the batches are not available without the stream
+ // yielding).
+ let tx = builder.tx();
+ builder.spawn(async move {
+ for batch in data {
+ println!("Sending batch via delayed stream");
+ if let Err(e) = tx.send(batch).await {
+ println!("ERROR batch via delayed stream: {e}");
+ }
}
- }
- });
-
- // returned stream simply reads off the rx stream
- Ok(RecordBatchReceiverStream::create(
- &self.schema,
- rx,
- join_handle,
- ))
+ });
+ // returned stream simply reads off the rx stream
+ Ok(builder.build())
+ } else {
+ // return the batches directly, without spawning a task
+ let stream = futures::stream::iter(data);
+ Ok(Box::pin(RecordBatchStreamAdapter::new(
+ self.schema(),
+ stream,
+ )))
+ }
}
fn fmt_as(
@@ -307,12 +331,13 @@ impl ExecutionPlan for BarrierExec {
) -> Result<SendableRecordBatchStream> {
assert!(partition < self.data.len());
- let (tx, rx) = tokio::sync::mpsc::channel(2);
+ let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2);
// task simply sends data in order after barrier is reached
let data = self.data[partition].clone();
let b = self.barrier.clone();
- let join_handle = tokio::task::spawn(async move {
+ let tx = builder.tx();
+ builder.spawn(async move {
println!("Partition {partition} waiting on barrier");
b.wait().await;
for batch in data {
@@ -324,11 +349,7 @@ impl ExecutionPlan for BarrierExec {
});
// returned stream simply reads off the rx stream
- Ok(RecordBatchReceiverStream::create(
- &self.schema,
- rx,
- join_handle,
- ))
+ Ok(builder.build())
}
fn fmt_as(
@@ -643,3 +664,144 @@ pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
.await
.unwrap();
}
+
+/// Execution plan that emits streams that panic.
+///
+/// This is useful to test panic handling of certain execution plans.
+#[derive(Debug)]
+pub struct PanicExec {
+ /// Schema that is mocked by this plan.
+ schema: SchemaRef,
+
+ /// Number of output partitions. Each partition will produce this
+ /// many empty output record batches prior to panicking
+ batches_until_panics: Vec<usize>,
+}
+
+impl PanicExec {
+ /// Create a new [`PanicExec`] with a given schema and number of
+ /// partitions, each of which will panic immediately.
+ pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
+ Self {
+ schema,
+ batches_until_panics: vec![0; n_partitions],
+ }
+ }
+
+ /// Set the number of batches prior to panic for a partition
+ pub fn with_partition_panic(mut self, partition: usize, count: usize) -> Self {
+ self.batches_until_panics[partition] = count;
+ self
+ }
+}
+
+impl ExecutionPlan for PanicExec {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn schema(&self) -> SchemaRef {
+ Arc::clone(&self.schema)
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ // this is a leaf node and has no children
+ vec![]
+ }
+
+ fn output_partitioning(&self) -> Partitioning {
+ let num_partitions = self.batches_until_panics.len();
+ Partitioning::UnknownPartitioning(num_partitions)
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ None
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ _: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> Result<Arc<dyn ExecutionPlan>> {
+ Err(DataFusionError::Internal(format!(
+ "Children cannot be replaced in {:?}",
+ self
+ )))
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ _context: Arc<TaskContext>,
+ ) -> Result<SendableRecordBatchStream> {
+ Ok(Box::pin(PanicStream {
+ partition,
+ batches_until_panic: self.batches_until_panics[partition],
+ schema: Arc::clone(&self.schema),
+ ready: false,
+ }))
+ }
+
+ fn fmt_as(
+ &self,
+ t: DisplayFormatType,
+ f: &mut std::fmt::Formatter,
+ ) -> std::fmt::Result {
+ match t {
+ DisplayFormatType::Default => {
+ write!(f, "PanicExec")
+ }
+ }
+ }
+
+ fn statistics(&self) -> Statistics {
+ unimplemented!()
+ }
+}
+
+/// A [`RecordBatchStream`] that yields a batch on every other poll and
+/// panics after `batches_until_panic` batches have been produced.
+///
+/// Useful for testing the behavior of streams on panic
+#[derive(Debug)]
+struct PanicStream {
+ /// Which partition this stream is for
+ partition: usize,
+ /// How many batches will be produced before panicking
+ batches_until_panic: usize,
+ /// Schema mocked by this stream.
+ schema: SchemaRef,
+ /// Should the next poll return `Ready`?
+ ready: bool,
+}
+
+impl Stream for PanicStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ if self.batches_until_panic > 0 {
+ if self.ready {
+ self.batches_until_panic -= 1;
+ self.ready = false;
+ let batch = RecordBatch::new_empty(self.schema.clone());
+ return Poll::Ready(Some(Ok(batch)));
+ } else {
+ self.ready = true;
+ // ensure we get polled again
+ cx.waker().clone().wake();
+ return Poll::Pending;
+ }
+ }
+ panic!("PanickingStream did panic: {}", self.partition)
+ }
+}
+
+impl RecordBatchStream for PanicStream {
+ fn schema(&self) -> SchemaRef {
+ Arc::clone(&self.schema)
+ }
+}