Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/01 15:15:43 UTC

[arrow-datafusion] branch master updated: Sort preserving merge (#362) (#379)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c794f2d  Sort preserving merge (#362) (#379)
c794f2d is described below

commit c794f2df539a10524566cb02b6158ee46cb1459a
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Jun 1 16:15:32 2021 +0100

    Sort preserving merge (#362) (#379)
    
    * Add SortPreservingMergeExec (#362)
    
    * Size MutableArrayData based on in_progress length
    
    * make SortPreservingMergeStream::build_record_batch fallible
    
    * Test SortPreservingMerge with different RecordBatch sizes
    
    * fix logical merge conflict
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/src/physical_plan/common.rs             |  39 +-
 datafusion/src/physical_plan/merge.rs              |  29 +-
 datafusion/src/physical_plan/mod.rs                |   1 +
 .../src/physical_plan/sort_preserving_merge.rs     | 949 +++++++++++++++++++++
 4 files changed, 988 insertions(+), 30 deletions(-)

diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index f1ed374..e60963b 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -22,13 +22,18 @@ use std::fs::metadata;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use super::{RecordBatchStream, SendableRecordBatchStream};
-use crate::error::{DataFusionError, Result};
-
 use arrow::datatypes::SchemaRef;
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
-use futures::{Stream, TryStreamExt};
+use futures::channel::mpsc;
+use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
+use tokio::task::JoinHandle;
+
+use crate::arrow::error::ArrowError;
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::ExecutionPlan;
+
+use super::{RecordBatchStream, SendableRecordBatchStream};
 
 /// Stream of record batches
 pub struct SizedRecordBatchStream {
@@ -113,3 +118,29 @@ fn build_file_list_recurse(
     }
     Ok(())
 }
+
+/// Spawns a task on the tokio threadpool and writes its outputs to the provided mpsc sender
+pub(crate) fn spawn_execution(
+    input: Arc<dyn ExecutionPlan>,
+    mut output: mpsc::Sender<ArrowResult<RecordBatch>>,
+    partition: usize,
+) -> JoinHandle<()> {
+    tokio::spawn(async move {
+        let mut stream = match input.execute(partition).await {
+            Err(e) => {
+                // If sending fails, the plan is being torn
+                // down and there is no place to send the error
+                let arrow_error = ArrowError::ExternalError(Box::new(e));
+                output.send(Err(arrow_error)).await.ok();
+                return;
+            }
+            Ok(stream) => stream,
+        };
+
+        while let Some(item) = stream.next().await {
+            // If sending fails, the plan is being torn down
+            // and there is no place to send the error
+            output.send(item).await.ok();
+        }
+    })
+}
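
The helper above is shared by MergeExec (refactored below) and the new
SortPreservingMergeExec. A minimal sketch of how a caller inside the crate
drives it, mirroring the execute() implementations later in this commit
(assumes an async context and a `plan: Arc<dyn ExecutionPlan>` in scope; the
sketch is illustrative and not itself part of the diff):

    use futures::channel::mpsc;

    let partitions = plan.output_partitioning().partition_count();
    let mut receivers = Vec::with_capacity(partitions);
    for partition in 0..partitions {
        // A bounded channel of size 1 provides backpressure: the spawned
        // task blocks on send until the consumer takes the previous batch
        let (sender, receiver) = mpsc::channel(1);
        spawn_execution(plan.clone(), sender, partition);
        receivers.push(receiver);
    }
    // Each receiver now yields ArrowResult<RecordBatch> items for one partition
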
diff --git a/datafusion/src/physical_plan/merge.rs b/datafusion/src/physical_plan/merge.rs
index c65227c..a25f5c7 100644
--- a/datafusion/src/physical_plan/merge.rs
+++ b/datafusion/src/physical_plan/merge.rs
@@ -22,23 +22,19 @@ use std::any::Any;
 use std::sync::Arc;
 
 use futures::channel::mpsc;
-use futures::sink::SinkExt;
-use futures::stream::StreamExt;
 use futures::Stream;
 
 use async_trait::async_trait;
 
 use arrow::record_batch::RecordBatch;
-use arrow::{
-    datatypes::SchemaRef,
-    error::{ArrowError, Result as ArrowResult},
-};
+use arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
 
 use super::RecordBatchStream;
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
 
 use super::SendableRecordBatchStream;
+use crate::physical_plan::common::spawn_execution;
 use pin_project_lite::pin_project;
 
 /// Merge execution plan executes partitions in parallel and combines them into a single
@@ -121,26 +117,7 @@ impl ExecutionPlan for MergeExec {
                 // spawn independent tasks whose resulting streams (of batches)
                 // are sent to the channel for consumption.
                 for part_i in 0..input_partitions {
-                    let input = self.input.clone();
-                    let mut sender = sender.clone();
-                    tokio::spawn(async move {
-                        let mut stream = match input.execute(part_i).await {
-                            Err(e) => {
-                                // If send fails, plan being torn
-                                // down, no place to send the error
-                                let arrow_error = ArrowError::ExternalError(Box::new(e));
-                                sender.send(Err(arrow_error)).await.ok();
-                                return;
-                            }
-                            Ok(stream) => stream,
-                        };
-
-                        while let Some(item) = stream.next().await {
-                            // If send fails, plan being torn down,
-                            // there is no place to send the error
-                            sender.send(item).await.ok();
-                        }
-                    });
+                    spawn_execution(self.input.clone(), sender.clone(), part_i);
                 }
 
                 Ok(Box::pin(MergeStream {
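
With this refactor MergeExec simply fans out one spawn_execution task per
input partition; its observable behavior is unchanged. For reference, a usage
sketch (assuming some `input: Arc<dyn ExecutionPlan>`; it mirrors the
`basic_sort` test helper in the new module below):

    use futures::StreamExt;

    let merge = Arc::new(MergeExec::new(input));
    // MergeExec always reports a single output partition, so execute(0)
    let mut stream = merge.execute(0).await?;
    while let Some(batch) = stream.next().await {
        let batch = batch?; // each item is an ArrowResult<RecordBatch>
        // consume the merged batch here
    }
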
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index ae84b36..af6969c 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -608,6 +608,7 @@ pub mod projection;
 pub mod regex_expressions;
 pub mod repartition;
 pub mod sort;
+pub mod sort_preserving_merge;
 pub mod source;
 pub mod string_expressions;
 pub mod type_coercion;
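
The new module registered above is the bulk of this commit. Constructing the
operator takes the sort expressions, the input plan, and a target batch size,
as in this sketch drawn from the module's own tests (assumes an
`input: Arc<dyn ExecutionPlan>` whose partitions are each sorted on column "b"):

    use std::sync::Arc;
    use crate::physical_plan::expressions::{col, PhysicalSortExpr};
    use crate::physical_plan::sort_preserving_merge::SortPreservingMergeExec;

    let merge = Arc::new(SortPreservingMergeExec::new(
        vec![PhysicalSortExpr {
            expr: col("b"),
            options: Default::default(),
        }],
        input,
        1024, // target size of the yielded batches
    ));
    // `merge` exposes a single output partition, still sorted on "b"
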
diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs
new file mode 100644
index 0000000..283294a
--- /dev/null
+++ b/datafusion/src/physical_plan/sort_preserving_merge.rs
@@ -0,0 +1,949 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the sort preserving merge plan
+
+use std::any::Any;
+use std::cmp::Ordering;
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::{ArrayRef, MutableArrayData};
+use arrow::compute::SortOptions;
+use async_trait::async_trait;
+use futures::channel::mpsc;
+use futures::stream::FusedStream;
+use futures::{Stream, StreamExt};
+
+use crate::arrow::datatypes::SchemaRef;
+use crate::arrow::error::ArrowError;
+use crate::arrow::{error::Result as ArrowResult, record_batch::RecordBatch};
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::common::spawn_execution;
+use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::{
+    DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
+    RecordBatchStream, SendableRecordBatchStream,
+};
+
+/// Sort preserving merge execution plan
+///
+/// This operator takes an input execution plan and a list of sort
+/// expressions and, provided each partition of the input is sorted with
+/// respect to these sort expressions, yields a single partition that is
+/// also sorted with respect to them
+#[derive(Debug)]
+pub struct SortPreservingMergeExec {
+    /// Input plan
+    input: Arc<dyn ExecutionPlan>,
+    /// Sort expressions
+    expr: Vec<PhysicalSortExpr>,
+    /// The target size of yielded batches
+    target_batch_size: usize,
+}
+
+impl SortPreservingMergeExec {
+    /// Create a new sort preserving merge execution plan
+    pub fn new(
+        expr: Vec<PhysicalSortExpr>,
+        input: Arc<dyn ExecutionPlan>,
+        target_batch_size: usize,
+    ) -> Self {
+        Self {
+            input,
+            expr,
+            target_batch_size,
+        }
+    }
+
+    /// Input plan
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// Sort expressions
+    pub fn expr(&self) -> &[PhysicalSortExpr] {
+        &self.expr
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for SortPreservingMergeExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.input.schema()
+    }
+
+    /// Get the output partitioning of this plan
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(1)
+    }
+
+    fn required_child_distribution(&self) -> Distribution {
+        Distribution::UnspecifiedDistribution
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.input.clone()]
+    }
+
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match children.len() {
+            1 => Ok(Arc::new(SortPreservingMergeExec::new(
+                self.expr.clone(),
+                children[0].clone(),
+                self.target_batch_size,
+            ))),
+            _ => Err(DataFusionError::Internal(
+                "SortPreservingMergeExec wrong number of children".to_string(),
+            )),
+        }
+    }
+
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        if 0 != partition {
+            return Err(DataFusionError::Internal(format!(
+                "SortPreservingMergeExec invalid partition {}",
+                partition
+            )));
+        }
+
+        let input_partitions = self.input.output_partitioning().partition_count();
+        match input_partitions {
+            0 => Err(DataFusionError::Internal(
+                "SortPreservingMergeExec requires at least one input partition"
+                    .to_owned(),
+            )),
+            1 => {
+                // bypass if there is only one partition to merge
+                self.input.execute(0).await
+            }
+            _ => {
+                let streams = (0..input_partitions)
+                    .into_iter()
+                    .map(|part_i| {
+                        let (sender, receiver) = mpsc::channel(1);
+                        spawn_execution(self.input.clone(), sender, part_i);
+                        receiver
+                    })
+                    .collect();
+
+                Ok(Box::pin(SortPreservingMergeStream::new(
+                    streams,
+                    self.schema(),
+                    &self.expr,
+                    self.target_batch_size,
+                )))
+            }
+        }
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
+                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+            }
+        }
+    }
+}
+
+/// A `SortKeyCursor` is created from a `RecordBatch` and a set of `PhysicalExpr` that,
+/// when evaluated on the `RecordBatch`, yield the sort keys.
+///
+/// Additionally it maintains a row cursor that can be advanced through the rows
+/// of the provided `RecordBatch`
+///
+/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this
+/// row cursor with that of another `SortKeyCursor`
+#[derive(Debug, Clone)]
+struct SortKeyCursor {
+    columns: Vec<ArrayRef>,
+    batch: RecordBatch,
+    cur_row: usize,
+    num_rows: usize,
+}
+
+impl SortKeyCursor {
+    fn new(batch: RecordBatch, sort_key: &[Arc<dyn PhysicalExpr>]) -> Result<Self> {
+        let columns = sort_key
+            .iter()
+            .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
+            .collect::<Result<_>>()?;
+
+        Ok(Self {
+            cur_row: 0,
+            num_rows: batch.num_rows(),
+            columns,
+            batch,
+        })
+    }
+
+    fn is_finished(&self) -> bool {
+        self.num_rows == self.cur_row
+    }
+
+    fn advance(&mut self) -> usize {
+        assert!(!self.is_finished());
+        let t = self.cur_row;
+        self.cur_row += 1;
+        t
+    }
+
+    /// Compares the sort key pointed to by this instance's row cursor with that of another
+    fn compare(
+        &self,
+        other: &SortKeyCursor,
+        options: &[SortOptions],
+    ) -> Result<Ordering> {
+        if self.columns.len() != other.columns.len() {
+            return Err(DataFusionError::Internal(format!(
+                "SortKeyCursors had inconsistent column counts: {} vs {}",
+                self.columns.len(),
+                other.columns.len()
+            )));
+        }
+
+        if self.columns.len() != options.len() {
+            return Err(DataFusionError::Internal(format!(
+                "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
+                self.columns.len(),
+                options.len()
+            )));
+        }
+
+        let zipped = self
+            .columns
+            .iter()
+            .zip(other.columns.iter())
+            .zip(options.iter());
+
+        for ((l, r), sort_options) in zipped {
+            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
+                (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
+                (false, true) => return Ok(Ordering::Greater),
+                (true, false) if sort_options.nulls_first => {
+                    return Ok(Ordering::Greater)
+                }
+                (true, false) => return Ok(Ordering::Less),
+                (false, false) => {}
+                (true, true) => {
+                    // TODO: Building the predicate each time is sub-optimal
+                    let c = arrow::array::build_compare(l.as_ref(), r.as_ref())?;
+                    match c(self.cur_row, other.cur_row) {
+                        Ordering::Equal => {}
+                        o if sort_options.descending => return Ok(o.reverse()),
+                        o => return Ok(o),
+                    }
+                }
+            }
+        }
+
+        Ok(Ordering::Equal)
+    }
+}
+
+/// A `RowIndex` identifies a specific row from those buffered
+/// by a `SortPreservingMergeStream`
+#[derive(Debug, Clone)]
+struct RowIndex {
+    /// The index of the stream
+    stream_idx: usize,
+    /// The index of the cursor within the stream's VecDeque
+    cursor_idx: usize,
+    /// The row index
+    row_idx: usize,
+}
+
+#[derive(Debug)]
+struct SortPreservingMergeStream {
+    /// The schema of the RecordBatches yielded by this stream
+    schema: SchemaRef,
+    /// The sorted input streams to merge together
+    streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+    /// For each input stream, maintain a VecDeque of SortKeyCursor
+    ///
+    /// Exhausted cursors will be popped off the front once all
+    /// their rows have been yielded to the output
+    cursors: Vec<VecDeque<SortKeyCursor>>,
+    /// The accumulated row indexes for the next record batch
+    in_progress: Vec<RowIndex>,
+    /// The physical expressions to sort by
+    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+    /// The sort options for each expression
+    sort_options: Vec<SortOptions>,
+    /// The desired RecordBatch size to yield
+    target_batch_size: usize,
+    /// If the stream has encountered an error
+    aborted: bool,
+}
+
+impl SortPreservingMergeStream {
+    fn new(
+        streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+        schema: SchemaRef,
+        expressions: &[PhysicalSortExpr],
+        target_batch_size: usize,
+    ) -> Self {
+        Self {
+            schema,
+            cursors: vec![Default::default(); streams.len()],
+            streams,
+            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
+            sort_options: expressions.iter().map(|x| x.options).collect(),
+            target_batch_size,
+            aborted: false,
+            in_progress: vec![],
+        }
+    }
+
+    /// If the stream at the given index is not exhausted, and the last cursor for the
+    /// stream is finished, poll the stream for the next RecordBatch and create a new
+    /// cursor for the stream from the returned result
+    fn maybe_poll_stream(
+        &mut self,
+        cx: &mut Context<'_>,
+        idx: usize,
+    ) -> Poll<ArrowResult<()>> {
+        if let Some(cursor) = &self.cursors[idx].back() {
+            if !cursor.is_finished() {
+                // Cursor is not finished - don't need a new RecordBatch yet
+                return Poll::Ready(Ok(()));
+            }
+        }
+
+        let stream = &mut self.streams[idx];
+        if stream.is_terminated() {
+            return Poll::Ready(Ok(()));
+        }
+
+        // Fetch a new record and create a cursor from it
+        match futures::ready!(stream.poll_next_unpin(cx)) {
+            None => return Poll::Ready(Ok(())),
+            Some(Err(e)) => {
+                return Poll::Ready(Err(e));
+            }
+            Some(Ok(batch)) => {
+                let cursor = match SortKeyCursor::new(batch, &self.column_expressions) {
+                    Ok(cursor) => cursor,
+                    Err(e) => {
+                        return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+                    }
+                };
+                self.cursors[idx].push_back(cursor)
+            }
+        }
+
+        Poll::Ready(Ok(()))
+    }
+
+    /// Returns the index of the next stream to pull a row from, or None
+    /// if all cursors for all streams are exhausted
+    fn next_stream_idx(&self) -> Result<Option<usize>> {
+        let mut min_cursor: Option<(usize, &SortKeyCursor)> = None;
+        for (idx, candidate) in self.cursors.iter().enumerate() {
+            if let Some(candidate) = candidate.back() {
+                if candidate.is_finished() {
+                    continue;
+                }
+
+                match min_cursor {
+                    None => min_cursor = Some((idx, candidate)),
+                    Some((_, ref min)) => {
+                        if min.compare(candidate, &self.sort_options)?
+                            == Ordering::Greater
+                        {
+                            min_cursor = Some((idx, candidate))
+                        }
+                    }
+                }
+            }
+        }
+
+        Ok(min_cursor.map(|(idx, _)| idx))
+    }
+
+    /// Drains the in_progress row indexes and builds a new RecordBatch from them
+    ///
+    /// It then drops any cursors for which all rows have been yielded to the output
+    fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
+        // Mapping from stream index to the index of the first buffer from that stream
+        let mut buffer_idx = 0;
+        let mut stream_to_buffer_idx = Vec::with_capacity(self.cursors.len());
+
+        for cursors in &self.cursors {
+            stream_to_buffer_idx.push(buffer_idx);
+            buffer_idx += cursors.len();
+        }
+
+        let columns = self
+            .schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(column_idx, field)| {
+                let arrays = self
+                    .cursors
+                    .iter()
+                    .flat_map(|cursor| {
+                        cursor
+                            .iter()
+                            .map(|cursor| cursor.batch.column(column_idx).data())
+                    })
+                    .collect();
+
+                let mut array_data = MutableArrayData::new(
+                    arrays,
+                    field.is_nullable(),
+                    self.in_progress.len(),
+                );
+
+                for row_index in &self.in_progress {
+                    let buffer_idx =
+                        stream_to_buffer_idx[row_index.stream_idx] + row_index.cursor_idx;
+
+                    // TODO: Coalesce contiguous writes
+                    array_data.extend(
+                        buffer_idx,
+                        row_index.row_idx,
+                        row_index.row_idx + 1,
+                    );
+                }
+
+                arrow::array::make_array(array_data.freeze())
+            })
+            .collect();
+
+        self.in_progress.clear();
+
+        // New cursors are only created once the previous cursor for the stream
+        // is finished. This means all remaining rows from all but the last cursor
+        // for each stream have been yielded to the newly created record batch
+        //
+        // Additionally, as `in_progress` has been drained, there are no longer
+        // any RowIndexes reliant on the cursor indexes
+        //
+        // We can therefore drop all but the last cursor for each stream
+        for cursors in &mut self.cursors {
+            if cursors.len() > 1 {
+                // Drain all but the last cursor
+                cursors.drain(0..(cursors.len() - 1));
+            }
+        }
+
+        RecordBatch::try_new(self.schema.clone(), columns)
+    }
+}
+
+impl Stream for SortPreservingMergeStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.aborted {
+            return Poll::Ready(None);
+        }
+
+        // Ensure all non-exhausted streams have a cursor from which
+        // rows can be pulled
+        for i in 0..self.cursors.len() {
+            match futures::ready!(self.maybe_poll_stream(cx, i)) {
+                Ok(_) => {}
+                Err(e) => {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+            }
+        }
+
+        loop {
+            let stream_idx = match self.next_stream_idx() {
+                Ok(Some(idx)) => idx,
+                Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None),
+                Ok(None) => return Poll::Ready(Some(self.build_record_batch())),
+                Err(e) => {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new(
+                        e,
+                    )))));
+                }
+            };
+
+            let cursors = &mut self.cursors[stream_idx];
+            let cursor_idx = cursors.len() - 1;
+            let cursor = cursors.back_mut().unwrap();
+            let row_idx = cursor.advance();
+            let cursor_finished = cursor.is_finished();
+
+            self.in_progress.push(RowIndex {
+                stream_idx,
+                cursor_idx,
+                row_idx,
+            });
+
+            if self.in_progress.len() == self.target_batch_size {
+                return Poll::Ready(Some(self.build_record_batch()));
+            }
+
+            // If we just took the last row from the cursor, fetch a new record
+            // batch if possible before looping round again
+            if cursor_finished {
+                match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
+                    Ok(_) => {}
+                    Err(e) => {
+                        self.aborted = true;
+                        return Poll::Ready(Some(Err(e)));
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for SortPreservingMergeStream {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::iter::FromIterator;
+
+    use crate::arrow::array::{Int32Array, StringArray, TimestampNanosecondArray};
+    use crate::assert_batches_eq;
+    use crate::datasource::CsvReadOptions;
+    use crate::physical_plan::csv::CsvExec;
+    use crate::physical_plan::expressions::col;
+    use crate::physical_plan::memory::MemoryExec;
+    use crate::physical_plan::merge::MergeExec;
+    use crate::physical_plan::sort::SortExec;
+    use crate::physical_plan::{collect, common};
+    use crate::test;
+
+    use super::*;
+    use futures::SinkExt;
+    use tokio_stream::StreamExt;
+
+    #[tokio::test]
+    async fn test_merge() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
+        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
+            Some("a"),
+            Some("b"),
+            Some("c"),
+            Some("d"),
+            Some("e"),
+        ]));
+        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 4]));
+        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
+
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
+            Some("d"),
+            Some("e"),
+            Some("g"),
+            Some("h"),
+            Some("i"),
+        ]));
+        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
+        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
+        let schema = b1.schema();
+
+        let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
+        let merge = Arc::new(SortPreservingMergeExec::new(
+            vec![
+                PhysicalSortExpr {
+                    expr: col("b"),
+                    options: Default::default(),
+                },
+                PhysicalSortExpr {
+                    expr: col("c"),
+                    options: Default::default(),
+                },
+            ],
+            Arc::new(exec),
+            1024,
+        ));
+
+        let collected = collect(merge).await.unwrap();
+        assert_eq!(collected.len(), 1);
+
+        assert_batches_eq!(
+            &[
+                "+---+---+-------------------------------+",
+                "| a | b | c                             |",
+                "+---+---+-------------------------------+",
+                "| 1 | a | 1970-01-01 00:00:00.000000008 |",
+                "| 2 | b | 1970-01-01 00:00:00.000000007 |",
+                "| 7 | c | 1970-01-01 00:00:00.000000006 |",
+                "| 1 | d | 1970-01-01 00:00:00.000000004 |",
+                "| 9 | d | 1970-01-01 00:00:00.000000005 |",
+                "| 3 | e | 1970-01-01 00:00:00.000000004 |",
+                "| 2 | e | 1970-01-01 00:00:00.000000006 |",
+                "| 3 | g | 1970-01-01 00:00:00.000000002 |",
+                "| 4 | h | 1970-01-01 00:00:00.000000002 |",
+                "| 5 | i | 1970-01-01 00:00:00.000000006 |",
+                "+---+---+-------------------------------+",
+            ],
+            collected.as_slice()
+        );
+    }
+
+    async fn sorted_merge(
+        input: Arc<dyn ExecutionPlan>,
+        sort: Vec<PhysicalSortExpr>,
+    ) -> RecordBatch {
+        let merge = Arc::new(SortPreservingMergeExec::new(sort, input, 1024));
+        let mut result = collect(merge).await.unwrap();
+        assert_eq!(result.len(), 1);
+        result.remove(0)
+    }
+
+    async fn partition_sort(
+        input: Arc<dyn ExecutionPlan>,
+        sort: Vec<PhysicalSortExpr>,
+    ) -> RecordBatch {
+        let sort_exec =
+            Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true));
+        sorted_merge(sort_exec, sort).await
+    }
+
+    async fn basic_sort(
+        src: Arc<dyn ExecutionPlan>,
+        sort: Vec<PhysicalSortExpr>,
+    ) -> RecordBatch {
+        let merge = Arc::new(MergeExec::new(src));
+        let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap());
+        let mut result = collect(sort_exec).await.unwrap();
+        assert_eq!(result.len(), 1);
+        result.remove(0)
+    }
+
+    #[tokio::test]
+    async fn test_partition_sort() {
+        let schema = test::aggr_test_schema();
+        let partitions = 4;
+        let path =
+            test::create_partitioned_csv("aggregate_test_100.csv", partitions).unwrap();
+        let csv = Arc::new(
+            CsvExec::try_new(
+                &path,
+                CsvReadOptions::new().schema(&schema),
+                None,
+                1024,
+                None,
+            )
+            .unwrap(),
+        );
+
+        let sort = vec![
+            PhysicalSortExpr {
+                expr: col("c1"),
+                options: SortOptions {
+                    descending: true,
+                    nulls_first: true,
+                },
+            },
+            PhysicalSortExpr {
+                expr: col("c2"),
+                options: Default::default(),
+            },
+            PhysicalSortExpr {
+                expr: col("c7"),
+                options: SortOptions::default(),
+            },
+        ];
+
+        let basic = basic_sort(csv.clone(), sort.clone()).await;
+        let partition = partition_sort(csv, sort).await;
+
+        let basic = arrow::util::pretty::pretty_format_batches(&[basic]).unwrap();
+        let partition = arrow::util::pretty::pretty_format_batches(&[partition]).unwrap();
+
+        assert_eq!(basic, partition);
+    }
+
+    // Split the provided record batch into multiple record batches of at most batch_size rows
+    fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
+        let batches = (sorted.num_rows() + batch_size - 1) / batch_size;
+
+        // Split the sorted RecordBatch into multiple slices of at most batch_size rows
+        (0..batches)
+            .into_iter()
+            .map(|batch_idx| {
+                let columns = (0..sorted.num_columns())
+                    .map(|column_idx| {
+                        let length =
+                            batch_size.min(sorted.num_rows() - batch_idx * batch_size);
+
+                        sorted
+                            .column(column_idx)
+                            .slice(batch_idx * batch_size, length)
+                    })
+                    .collect();
+
+                RecordBatch::try_new(sorted.schema(), columns).unwrap()
+            })
+            .collect()
+    }
+
+    async fn sorted_partitioned_input(
+        sort: Vec<PhysicalSortExpr>,
+        sizes: &[usize],
+    ) -> Arc<dyn ExecutionPlan> {
+        let schema = test::aggr_test_schema();
+        let partitions = 4;
+        let path =
+            test::create_partitioned_csv("aggregate_test_100.csv", partitions).unwrap();
+        let csv = Arc::new(
+            CsvExec::try_new(
+                &path,
+                CsvReadOptions::new().schema(&schema),
+                None,
+                1024,
+                None,
+            )
+            .unwrap(),
+        );
+
+        let sorted = basic_sort(csv, sort).await;
+        let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
+
+        Arc::new(MemoryExec::try_new(&split, sorted.schema(), None).unwrap())
+    }
+
+    #[tokio::test]
+    async fn test_partition_sort_streaming_input() {
+        let sort = vec![
+            // uint8
+            PhysicalSortExpr {
+                expr: col("c7"),
+                options: Default::default(),
+            },
+            // int16
+            PhysicalSortExpr {
+                expr: col("c4"),
+                options: Default::default(),
+            },
+            // utf-8
+            PhysicalSortExpr {
+                expr: col("c1"),
+                options: SortOptions::default(),
+            },
+            // utf-8
+            PhysicalSortExpr {
+                expr: col("c13"),
+                options: SortOptions::default(),
+            },
+        ];
+
+        let input = sorted_partitioned_input(sort.clone(), &[10, 3, 11]).await;
+        let basic = basic_sort(input.clone(), sort.clone()).await;
+        let partition = sorted_merge(input, sort).await;
+
+        assert_eq!(basic.num_rows(), 300);
+        assert_eq!(partition.num_rows(), 300);
+
+        let basic = arrow::util::pretty::pretty_format_batches(&[basic]).unwrap();
+        let partition = arrow::util::pretty::pretty_format_batches(&[partition]).unwrap();
+
+        assert_eq!(basic, partition);
+    }
+
+    #[tokio::test]
+    async fn test_partition_sort_streaming_input_output() {
+        let sort = vec![
+            // float64
+            PhysicalSortExpr {
+                expr: col("c12"),
+                options: Default::default(),
+            },
+            // utf-8
+            PhysicalSortExpr {
+                expr: col("c13"),
+                options: Default::default(),
+            },
+        ];
+
+        let input = sorted_partitioned_input(sort.clone(), &[10, 5, 13]).await;
+        let basic = basic_sort(input.clone(), sort.clone()).await;
+
+        let merge = Arc::new(SortPreservingMergeExec::new(sort, input, 23));
+        let merged = collect(merge).await.unwrap();
+
+        assert_eq!(merged.len(), 14);
+
+        assert_eq!(basic.num_rows(), 300);
+        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 300);
+
+        let basic = arrow::util::pretty::pretty_format_batches(&[basic]).unwrap();
+        let partition =
+            arrow::util::pretty::pretty_format_batches(merged.as_slice()).unwrap();
+
+        assert_eq!(basic, partition);
+    }
+
+    #[tokio::test]
+    async fn test_nulls() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
+        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
+            None,
+            Some("a"),
+            Some("b"),
+            Some("d"),
+            Some("e"),
+        ]));
+        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
+            Some(8),
+            None,
+            Some(6),
+            None,
+            Some(4),
+        ]));
+        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
+
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
+        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
+            None,
+            Some("b"),
+            Some("g"),
+            Some("h"),
+            Some("i"),
+        ]));
+        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
+            Some(8),
+            None,
+            Some(5),
+            None,
+            Some(4),
+        ]));
+        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
+        let schema = b1.schema();
+
+        let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
+        let merge = Arc::new(SortPreservingMergeExec::new(
+            vec![
+                PhysicalSortExpr {
+                    expr: col("b"),
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: true,
+                    },
+                },
+                PhysicalSortExpr {
+                    expr: col("c"),
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: false,
+                    },
+                },
+            ],
+            Arc::new(exec),
+            1024,
+        ));
+
+        let collected = collect(merge).await.unwrap();
+        assert_eq!(collected.len(), 1);
+
+        assert_batches_eq!(
+            &[
+                "+---+---+-------------------------------+",
+                "| a | b | c                             |",
+                "+---+---+-------------------------------+",
+                "| 1 |   | 1970-01-01 00:00:00.000000008 |",
+                "| 1 |   | 1970-01-01 00:00:00.000000008 |",
+                "| 2 | a |                               |",
+                "| 7 | b | 1970-01-01 00:00:00.000000006 |",
+                "| 2 | b |                               |",
+                "| 9 | d |                               |",
+                "| 3 | e | 1970-01-01 00:00:00.000000004 |",
+                "| 3 | g | 1970-01-01 00:00:00.000000005 |",
+                "| 4 | h |                               |",
+                "| 5 | i | 1970-01-01 00:00:00.000000004 |",
+                "+---+---+-------------------------------+",
+            ],
+            collected.as_slice()
+        );
+    }
+
+    #[tokio::test]
+    async fn test_async() {
+        let sort = vec![PhysicalSortExpr {
+            expr: col("c7"),
+            options: SortOptions::default(),
+        }];
+
+        let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3]).await;
+
+        let partition_count = batches.output_partitioning().partition_count();
+        let mut tasks = Vec::with_capacity(partition_count);
+        let mut streams = Vec::with_capacity(partition_count);
+
+        for partition in 0..partition_count {
+            let (mut sender, receiver) = mpsc::channel(1);
+            let mut stream = batches.execute(partition).await.unwrap();
+            let task = tokio::spawn(async move {
+                while let Some(batch) = stream.next().await {
+                    sender.send(batch).await.unwrap();
+                    // This causes the SortPreservingMergeStream to wait for more input
+                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
+                }
+            });
+            tasks.push(task);
+            streams.push(receiver);
+        }
+
+        let merge_stream = SortPreservingMergeStream::new(
+            streams,
+            batches.schema(),
+            sort.as_slice(),
+            1024,
+        );
+
+        let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
+
+        // Propagate any errors
+        for task in tasks {
+            task.await.unwrap();
+        }
+
+        assert_eq!(merged.len(), 1);
+        let merged = merged.remove(0);
+        let basic = basic_sort(batches, sort.clone()).await;
+
+        let basic = arrow::util::pretty::pretty_format_batches(&[basic]).unwrap();
+        let partition = arrow::util::pretty::pretty_format_batches(&[merged]).unwrap();
+
+        assert_eq!(basic, partition);
+    }
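+
+    // A sketch added for illustration (not part of the original commit):
+    // it exercises the SortKeyCursor comparison semantics documented above.
+    // The cursor whose current row holds the smaller sort key compares as
+    // Ordering::Less under the default ascending sort options.
+    #[test]
+    fn test_sort_key_cursor_compare() {
+        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 3]));
+        let b: ArrayRef = Arc::new(Int32Array::from(vec![2, 4]));
+        let batch_a = RecordBatch::try_from_iter(vec![("x", a)]).unwrap();
+        let batch_b = RecordBatch::try_from_iter(vec![("x", b)]).unwrap();
+
+        let sort_key: Vec<Arc<dyn PhysicalExpr>> = vec![col("x")];
+        let cursor_a = SortKeyCursor::new(batch_a, &sort_key).unwrap();
+        let cursor_b = SortKeyCursor::new(batch_b, &sort_key).unwrap();
+
+        let options = vec![SortOptions::default()];
+        // row 0 of batch_a holds 1, row 0 of batch_b holds 2
+        assert_eq!(
+            cursor_a.compare(&cursor_b, &options).unwrap(),
+            Ordering::Less
+        );
+        assert_eq!(
+            cursor_b.compare(&cursor_a, &options).unwrap(),
+            Ordering::Greater
+        );
+    }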
+}