You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/05/24 13:34:04 UTC

[GitHub] [arrow-datafusion] jhorstmann commented on a change in pull request #379: Sort preserving merge (#362)

jhorstmann commented on a change in pull request #379:
URL: https://github.com/apache/arrow-datafusion/pull/379#discussion_r637948151



##########
File path: datafusion/src/physical_plan/sort_preserving_merge.rs
##########
@@ -0,0 +1,955 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the sort preserving merge plan
+
+use std::any::Any;
+use std::cmp::Ordering;
+use std::collections::VecDeque;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use arrow::array::{ArrayRef, MutableArrayData};
+use arrow::compute::SortOptions;
+use async_trait::async_trait;
+use futures::channel::mpsc;
+use futures::stream::FusedStream;
+use futures::{Stream, StreamExt};
+
+use crate::arrow::datatypes::SchemaRef;
+use crate::arrow::error::ArrowError;
+use crate::arrow::{error::Result as ArrowResult, record_batch::RecordBatch};
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::common::spawn_execution;
+use crate::physical_plan::expressions::PhysicalSortExpr;
+use crate::physical_plan::{
+    DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
+    RecordBatchStream, SendableRecordBatchStream,
+};
+
+/// Sort preserving merge execution plan
+///
+/// Takes an input execution plan and a list of sort expressions. Provided
+/// that each partition of the input plan is already sorted with respect to
+/// these sort expressions, this operator yields a single output partition
+/// that is also sorted with respect to them.
+///
+/// The operator only merges; it does not sort its inputs itself, so
+/// sortedness of every input partition is a precondition the caller must
+/// guarantee.
+#[derive(Debug)]
+pub struct SortPreservingMergeExec {
+    /// Input plan whose (individually sorted) partitions are merged
+    input: Arc<dyn ExecutionPlan>,
+    /// Sort expressions defining the order shared by all input partitions
+    expr: Vec<PhysicalSortExpr>,
+    /// Target size of the batches yielded by the merge stream; passed
+    /// through unchanged to `SortPreservingMergeStream`
+    target_batch_size: usize,
+}
+
+impl SortPreservingMergeExec {
+    /// Create a new sort preserving merge execution plan.
+    ///
+    /// `expr` is the sort order every partition of `input` is assumed to
+    /// already satisfy; `target_batch_size` is the desired size of the
+    /// batches produced by the merge.
+    pub fn new(
+        expr: Vec<PhysicalSortExpr>,
+        input: Arc<dyn ExecutionPlan>,
+        target_batch_size: usize,
+    ) -> Self {
+        Self {
+            input,
+            expr,
+            target_batch_size,
+        }
+    }
+
+    /// Input execution plan whose partitions are merged
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.input
+    }
+
+    /// Sort expressions defining the merge order
+    pub fn expr(&self) -> &[PhysicalSortExpr] {
+        &self.expr
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for SortPreservingMergeExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Merging does not change columns, so the schema is the input's schema
+    fn schema(&self) -> SchemaRef {
+        self.input.schema()
+    }
+
+    /// Get the output partitioning of this plan: all input partitions are
+    /// merged into exactly one output partition
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(1)
+    }
+
+    /// No particular distribution is required of the child; the merge
+    /// consumes however many partitions the input has
+    fn required_child_distribution(&self) -> Distribution {
+        Distribution::UnspecifiedDistribution
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.input.clone()]
+    }
+
+    /// Rebuild this node around a new input, keeping the same sort
+    /// expressions and target batch size.
+    ///
+    /// Returns an `Internal` error unless exactly one child is supplied.
+    fn with_new_children(
+        &self,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        match children.len() {
+            1 => Ok(Arc::new(SortPreservingMergeExec::new(
+                self.expr.clone(),
+                children[0].clone(),
+                self.target_batch_size,
+            ))),
+            _ => Err(DataFusionError::Internal(
+                "SortPreservingMergeExec wrong number of children".to_string(),
+            )),
+        }
+    }
+
+    /// Execute the single output partition (`partition` must be 0).
+    ///
+    /// - 0 input partitions: `Internal` error.
+    /// - 1 input partition: the input stream is returned directly, since it
+    ///   is already sorted and there is nothing to merge.
+    /// - N > 1 input partitions: each partition is started via
+    ///   `spawn_execution` feeding its own bounded channel, and the receivers
+    ///   are merged by a `SortPreservingMergeStream`.
+    async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream> {
+        // Only one output partition exists (see `output_partitioning`)
+        if 0 != partition {
+            return Err(DataFusionError::Internal(format!(
+                "SortPreservingMergeExec invalid partition {}",
+                partition
+            )));
+        }
+
+        let input_partitions = self.input.output_partitioning().partition_count();
+        match input_partitions {
+            0 => Err(DataFusionError::Internal(
+                "SortPreservingMergeExec requires at least one input partition"
+                    .to_owned(),
+            )),
+            1 => {
+                // bypass if there is only one partition to merge
+                self.input.execute(0).await
+            }
+            _ => {
+                // One small bounded channel per input partition. NOTE(review):
+                // `spawn_execution` presumably runs the partition in a
+                // background task and forwards its batches into `sender`;
+                // confirm in `physical_plan::common`. Also, `.into_iter()` on
+                // a range is redundant (clippy `useless_conversion`).
+                let streams = (0..input_partitions)
+                    .into_iter()
+                    .map(|part_i| {
+                        let (sender, receiver) = mpsc::channel(1);
+                        spawn_execution(self.input.clone(), sender, part_i);
+                        receiver
+                    })
+                    .collect();
+
+                Ok(Box::pin(SortPreservingMergeStream::new(
+                    streams,
+                    self.schema(),
+                    &self.expr,
+                    self.target_batch_size,
+                )))
+            }
+        }
+    }
+
+    /// Render as `SortPreservingMergeExec: [expr1,expr2,...]`
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
+                write!(f, "SortPreservingMergeExec: [{}]", expr.join(","))
+            }
+        }
+    }
+}
+
+/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of `PhysicalExpr` that when
+/// evaluated on the `RecordBatch` yield the sort keys.
+///
+/// Additionally it maintains a row cursor that can be advanced through the rows
+/// of the provided `RecordBatch`
+///
+/// `SortKeyCursor::compare` can then be used to compare the sort key pointed to by this
+/// row cursor, with that of another `SortKeyCursor`
+#[derive(Debug, Clone)]
+struct SortKeyCursor {
+    /// Evaluated sort-key columns, one per sort expression, each with
+    /// `num_rows` entries
+    columns: Vec<ArrayRef>,
+    /// The batch the keys were evaluated against; kept so its rows can be
+    /// emitted later
+    batch: RecordBatch,
+    /// Index of the next row to be consumed; equals `num_rows` once finished
+    cur_row: usize,
+    /// Number of rows in `batch`
+    num_rows: usize,
+}
+
+impl SortKeyCursor {
+    /// Evaluate every expression in `sort_key` against `batch` and create a
+    /// cursor positioned at row 0.
+    ///
+    /// Returns the first evaluation error, if any.
+    fn new(batch: RecordBatch, sort_key: &[Arc<dyn PhysicalExpr>]) -> Result<Self> {
+        // `into_array(num_rows)` materializes each evaluation result as an
+        // array of the batch's length (presumably expanding scalar results)
+        let columns = sort_key
+            .iter()
+            .map(|expr| Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
+            .collect::<Result<_>>()?;
+
+        Ok(Self {
+            cur_row: 0,
+            num_rows: batch.num_rows(),
+            columns,
+            batch,
+        })
+    }
+
+    /// True once every row of the batch has been consumed
+    /// (always true for an empty batch)
+    fn is_finished(&self) -> bool {
+        self.num_rows == self.cur_row
+    }
+
+    /// Consume the current row and return its index.
+    ///
+    /// # Panics
+    /// Panics if the cursor is already finished.
+    fn advance(&mut self) -> usize {
+        assert!(!self.is_finished());
+        let t = self.cur_row;
+        self.cur_row += 1;
+        t
+    }
+
+    /// Compares the sort key pointed to by this instance's row cursor with that of another
+    ///
+    /// Columns are compared left to right and the first non-equal column
+    /// decides. For each column, `options[i].nulls_first` alone decides where
+    /// nulls go, and `options[i].descending` reverses only comparisons between
+    /// two non-null values. Two nulls compare equal.
+    ///
+    /// Returns an `Internal` error if the two cursors have different numbers
+    /// of key columns, or if `options` does not have one entry per column.
+    ///
+    /// NOTE(review): both cursors are assumed not to be finished; a finished
+    /// cursor has `cur_row == num_rows`, which is past the end of its arrays.
+    fn compare(
+        &self,
+        other: &SortKeyCursor,
+        options: &[SortOptions],
+    ) -> Result<Ordering> {
+        if self.columns.len() != other.columns.len() {
+            return Err(DataFusionError::Internal(format!(
+                "SortKeyCursors had inconsistent column counts: {} vs {}",
+                self.columns.len(),
+                other.columns.len()
+            )));
+        }
+
+        if self.columns.len() != options.len() {
+            return Err(DataFusionError::Internal(format!(
+                "Incorrect number of SortOptions provided to SortKeyCursor::compare, expected {} got {}",
+                self.columns.len(),
+                options.len()
+            )));
+        }
+
+        let zipped = self
+            .columns
+            .iter()
+            .zip(other.columns.iter())
+            .zip(options.iter());
+
+        for ((l, r), sort_options) in zipped {
+            // Match on (left is non-null, right is non-null)
+            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
+                (false, true) if sort_options.nulls_first => return Ok(Ordering::Less),
+                (false, true) => return Ok(Ordering::Greater),
+                (true, false) if sort_options.nulls_first => {
+                    return Ok(Ordering::Greater)
+                }
+                (true, false) => return Ok(Ordering::Less),
+                // Both null: tie on this column, move on to the next one
+                (false, false) => {}
+                (true, true) => {
+                    // TODO: building the comparator on every call is sub-optimal;
+                    // it could be built once per pair of batches
+                    let c = arrow::array::build_compare(l.as_ref(), r.as_ref())?;
+                    match c(self.cur_row, other.cur_row) {
+                        Ordering::Equal => {}
+                        o if sort_options.descending => return Ok(o.reverse()),
+                        o => return Ok(o),
+                    }
+                }
+            }
+        }
+
+        // All key columns tied
+        Ok(Ordering::Equal)
+    }
+}
+
+/// A `RowIndex` identifies a specific row from those buffered
+/// by a `SortPreservingMergeStream`
+#[derive(Debug, Clone)]
+struct RowIndex {
+    /// The index of the input stream the row came from
+    stream_idx: usize,
+    /// The index of the cursor within that stream's `VecDeque` of cursors
+    cursor_idx: usize,
+    /// The row index within the cursor's `RecordBatch`
+    row_idx: usize,
+}
+
+/// Stream that merges several individually sorted input streams into a
+/// single sorted stream of `RecordBatch`es
+#[derive(Debug)]
+struct SortPreservingMergeStream {
+    /// The schema of the RecordBatches yielded by this stream
+    schema: SchemaRef,
+    /// The sorted input streams to merge together
+    streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+    /// For each input stream, a deque of `SortKeyCursor`s (one per received
+    /// batch), indexed in parallel with `streams`
+    ///
+    /// Exhausted cursors will be popped off the front once all
+    /// their rows have been yielded to the output
+    cursors: Vec<VecDeque<SortKeyCursor>>,
+    /// The accumulated row indexes for the next record batch
+    in_progress: Vec<RowIndex>,
+    /// The physical expressions to sort by
+    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+    /// The sort options for each expression, parallel to `column_expressions`
+    sort_options: Vec<SortOptions>,
+    /// The desired RecordBatch size to yield
+    target_batch_size: usize,
+    /// If the stream has encountered an error
+    aborted: bool,
+}
+
+impl SortPreservingMergeStream {
+    /// Create a merge stream over `streams`, ordered by `expressions`.
+    ///
+    /// Starts with one empty cursor deque per input stream, no buffered
+    /// output rows, and `aborted == false`.
+    fn new(
+        streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+        schema: SchemaRef,
+        expressions: &[PhysicalSortExpr],
+        target_batch_size: usize,
+    ) -> Self {
+        Self {
+            schema,
+            // Must be initialized before `streams` is moved into the struct,
+            // because it reads `streams.len()`
+            cursors: vec![Default::default(); streams.len()],
+            streams,
+            // Split each sort expression into its expression and options
+            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
+            sort_options: expressions.iter().map(|x| x.options).collect(),
+            target_batch_size,
+            aborted: false,
+            in_progress: vec![],
+        }
+    }
+
+    /// If the stream at the given index is not exhausted, and the last cursor for the
+    /// stream is finished, poll the stream for the next RecordBatch and create a new
+    /// cursor for the stream from the returned result
+    ///
+    /// Returns `Poll::Pending` only while waiting on the input stream.
+    /// `Poll::Ready(Ok(()))` means no poll was needed, the stream is
+    /// exhausted, or a new cursor was pushed onto `self.cursors[idx]`.
+    /// Errors from the stream, or from evaluating the sort keys (wrapped in
+    /// `ArrowError::ExternalError`), are returned as `Poll::Ready(Err(_))`.
+    fn maybe_poll_stream(
+        &mut self,
+        cx: &mut Context<'_>,
+        idx: usize,
+    ) -> Poll<ArrowResult<()>> {
+        if let Some(cursor) = &self.cursors[idx].back() {
+            if !cursor.is_finished() {
+                // Cursor is not finished - don't need a new RecordBatch yet
+                return Poll::Ready(Ok(()));
+            }
+        }
+
+        // Do not poll a stream that has already finished
+        let stream = &mut self.streams[idx];
+        if stream.is_terminated() {
+            return Poll::Ready(Ok(()));
+        }
+
+        // Fetch a new record and create a cursor from it
+        // (`ready!` returns `Poll::Pending` early if no batch is available yet)
+        match futures::ready!(stream.poll_next_unpin(cx)) {
+            None => return Poll::Ready(Ok(())),
+            Some(Err(e)) => {
+                return Poll::Ready(Err(e));
+            }
+            Some(Ok(batch)) => {
+                let cursor = match SortKeyCursor::new(batch, &self.column_expressions) {
+                    Ok(cursor) => cursor,
+                    Err(e) => {
+                        return Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+                    }
+                };
+                self.cursors[idx].push_back(cursor)
+            }
+        }
+
+        Poll::Ready(Ok(()))
+    }
+
+    /// Returns the index of the next stream to pull a row from, or None
+    /// if all cursors for all streams are exhausted
+    fn next_stream_idx(&self) -> Result<Option<usize>> {
+        let mut min_cursor: Option<(usize, &SortKeyCursor)> = None;
+        for (idx, candidate) in self.cursors.iter().enumerate() {

Review comment:
       For a larger number of partitions, storing the cursors in a BinaryHeap ordered by each cursor's current row would be beneficial, since choosing the next row would then avoid a linear scan over all cursors.
   
   A Rust implementation of that approach can be seen in [this blog post and the first comment under it][1]. I have implemented the same approach in Java before. I agree with @alamb, though, that we should make it work first and optimize later.
   
   [1]: https://dev.to/creativcoder/merge-k-sorted-arrays-in-rust-1b2f




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org