You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/04/11 08:35:35 UTC

[arrow-datafusion] branch main updated: Specialize Primitive Cursor -- make sorts / merges on a single primitive column faster (#5897)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4675a58eeb Specialize Primitive Cursor -- make sorts / merges on a single primitive column faster (#5897)
4675a58eeb is described below

commit 4675a58eeb13afbe4798b95d92492f6922ebf47e
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Apr 11 09:35:28 2023 +0100

    Specialize Primitive Cursor -- make sorts / merges on a single primitive column faster (#5897)
    
    * Specialize PrimitiveCursor (#5882)
    
    * Toml format
    
    * Review feedback
---
 datafusion-cli/Cargo.lock                         |   2 +
 datafusion/core/Cargo.toml                        |   2 +
 datafusion/core/src/physical_plan/sorts/cursor.rs | 232 ++++++++++++++++++++++
 datafusion/core/src/physical_plan/sorts/merge.rs  |  28 ++-
 datafusion/core/src/physical_plan/sorts/stream.rs |  74 ++++++-
 5 files changed, 333 insertions(+), 5 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 0f26e5720c..cd997b134d 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -698,6 +698,8 @@ version = "22.0.0"
 dependencies = [
  "ahash",
  "arrow",
+ "arrow-array",
+ "arrow-schema",
  "async-compression",
  "async-trait",
  "bytes",
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index ed3b23a9af..0e15e45a46 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -57,6 +57,8 @@ unicode_expressions = ["datafusion-physical-expr/unicode_expressions", "datafusi
 ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
 apache-avro = { version = "0.14", optional = true }
 arrow = { workspace = true }
+arrow-array = { workspace = true }
+arrow-schema = { workspace = true }
 async-compression = { version = "0.3.14", features = ["bzip2", "gzip", "xz", "zstd", "futures-io", "tokio"], optional = true }
 async-trait = "0.1.41"
 bytes = "1.4"
diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs
index 3507a5b224..7e8d600542 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::physical_plan::sorts::sort::SortOptions;
+use arrow::buffer::ScalarBuffer;
+use arrow::datatypes::ArrowNativeTypeOp;
 use arrow::row::{Row, Rows};
 use std::cmp::Ordering;
 
@@ -93,3 +96,232 @@ impl Cursor for RowCursor {
         t
     }
 }
+
+/// A cursor over sorted, nullable [`ArrowNativeTypeOp`]
+///
+/// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
+#[derive(Debug)]
+pub struct PrimitiveCursor<T: ArrowNativeTypeOp> {
+    values: ScalarBuffer<T>,
+    offset: usize,
+    // If nulls first, the first non-null index
+    // Otherwise, the first null index
+    null_threshold: usize,
+    options: SortOptions,
+}
+
+impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
+    /// Create a new [`PrimitiveCursor`] from the provided `values` sorted according to `options`
+    pub fn new(options: SortOptions, values: ScalarBuffer<T>, null_count: usize) -> Self {
+        assert!(null_count <= values.len());
+
+        let null_threshold = match options.nulls_first {
+            true => null_count,
+            false => values.len() - null_count,
+        };
+
+        Self {
+            values,
+            offset: 0,
+            null_threshold,
+            options,
+        }
+    }
+
+    fn is_null(&self) -> bool {
+        (self.offset < self.null_threshold) == self.options.nulls_first
+    }
+
+    fn value(&self) -> T {
+        self.values[self.offset]
+    }
+}
+
+impl<T: ArrowNativeTypeOp> PartialEq for PrimitiveCursor<T> {
+    fn eq(&self, other: &Self) -> bool {
+        self.cmp(other).is_eq()
+    }
+}
+
+impl<T: ArrowNativeTypeOp> Eq for PrimitiveCursor<T> {}
+impl<T: ArrowNativeTypeOp> PartialOrd for PrimitiveCursor<T> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        match (self.is_null(), other.is_null()) {
+            (true, true) => Ordering::Equal,
+            (true, false) => match self.options.nulls_first {
+                true => Ordering::Less,
+                false => Ordering::Greater,
+            },
+            (false, true) => match self.options.nulls_first {
+                true => Ordering::Greater,
+                false => Ordering::Less,
+            },
+            (false, false) => {
+                let s_v = self.value();
+                let o_v = other.value();
+
+                match self.options.descending {
+                    true => o_v.compare(s_v),
+                    false => s_v.compare(o_v),
+                }
+            }
+        }
+    }
+}
+
+impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
+    fn is_finished(&self) -> bool {
+        self.offset == self.values.len()
+    }
+
+    fn advance(&mut self) -> usize {
+        let t = self.offset;
+        self.offset += 1;
+        t
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_primitive_nulls_first() {
+        let options = SortOptions {
+            descending: false,
+            nulls_first: true,
+        };
+
+        let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
+        let mut a = PrimitiveCursor::new(options, buffer, 1);
+        let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
+        let mut b = PrimitiveCursor::new(options, buffer, 2);
+
+        // NULL == NULL
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        // NULL == NULL
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        // NULL < -2
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 1 > -2
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Greater);
+
+        // 1 > -1
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Greater);
+
+        // 1 == 1
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        // 1 < 9
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 2 < 9
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        let options = SortOptions {
+            descending: false,
+            nulls_first: false,
+        };
+
+        let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
+        let mut a = PrimitiveCursor::new(options, buffer, 2);
+        let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
+        let mut b = PrimitiveCursor::new(options, buffer, 2);
+
+        // 0 > -1
+        assert_eq!(a.cmp(&b), Ordering::Greater);
+
+        // 0 < NULL
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 1 < NULL
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // NULL == NULL
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        let options = SortOptions {
+            descending: true,
+            nulls_first: false,
+        };
+
+        let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
+        let mut a = PrimitiveCursor::new(options, buffer, 3);
+        let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
+        let mut b = PrimitiveCursor::new(options, buffer, 2);
+
+        // 6 > 67 (in descending sort order)
+        assert_eq!(a.cmp(&b), Ordering::Greater);
+
+        // 6 < -3 (in descending sort order)
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 6 < NULL
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 6 < NULL
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // NULL == NULL
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        let options = SortOptions {
+            descending: true,
+            nulls_first: true,
+        };
+
+        let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
+        let mut a = PrimitiveCursor::new(options, buffer, 2);
+        let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
+        let mut b = PrimitiveCursor::new(options, buffer, 1);
+
+        // NULL == NULL
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        // NULL == NULL
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Equal);
+        assert_eq!(a, b);
+
+        // NULL < 4546
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+
+        // 6 > 4546 (in descending sort order)
+        a.advance();
+        assert_eq!(a.cmp(&b), Ordering::Greater);
+
+        // 6 < -3 (in descending sort order)
+        b.advance();
+        assert_eq!(a.cmp(&b), Ordering::Less);
+    }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
index 502c37e641..1ea89b9a81 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -19,16 +19,31 @@ use crate::common::Result;
 use crate::physical_plan::metrics::MemTrackingMetrics;
 use crate::physical_plan::sorts::builder::BatchBuilder;
 use crate::physical_plan::sorts::cursor::Cursor;
-use crate::physical_plan::sorts::stream::{PartitionedStream, RowCursorStream};
+use crate::physical_plan::sorts::stream::{
+    PartitionedStream, PrimitiveCursorStream, RowCursorStream,
+};
 use crate::physical_plan::{
     PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
 };
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
+use arrow_array::downcast_primitive;
 use futures::Stream;
 use std::pin::Pin;
 use std::task::{ready, Context, Poll};
 
+macro_rules! primitive_merge_helper {
+    ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
+        let streams = PrimitiveCursorStream::<$t>::new($sort, $streams);
+        return Ok(Box::pin(SortPreservingMergeStream::new(
+            Box::new(streams),
+            $schema,
+            $tracking_metrics,
+            $batch_size,
+        )));
+    }};
+}
+
 /// Perform a streaming merge of [`SendableRecordBatchStream`]
 pub(crate) fn streaming_merge(
     streams: Vec<SendableRecordBatchStream>,
@@ -37,8 +52,17 @@ pub(crate) fn streaming_merge(
     tracking_metrics: MemTrackingMetrics,
     batch_size: usize,
 ) -> Result<SendableRecordBatchStream> {
-    let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
+    // Special case single column comparisons with optimized cursor implementations
+    if expressions.len() == 1 {
+        let sort = expressions[0].clone();
+        let data_type = sort.expr.data_type(schema.as_ref())?;
+        downcast_primitive! {
+            data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size),
+            _ => {}
+        }
+    }
 
+    let streams = RowCursorStream::try_new(schema.as_ref(), expressions, streams)?;
     Ok(Box::pin(SortPreservingMergeStream::new(
         Box::new(streams),
         schema,
diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/core/src/physical_plan/sorts/stream.rs
index 3fe68624f7..d6f49e58b5 100644
--- a/datafusion/core/src/physical_plan/sorts/stream.rs
+++ b/datafusion/core/src/physical_plan/sorts/stream.rs
@@ -16,13 +16,16 @@
 // under the License.
 
 use crate::common::Result;
-use crate::physical_plan::sorts::cursor::RowCursor;
+use crate::physical_plan::sorts::cursor::{PrimitiveCursor, RowCursor};
 use crate::physical_plan::SendableRecordBatchStream;
 use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
+use arrow::array::{Array, ArrowPrimitiveType};
 use arrow::datatypes::Schema;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, SortField};
+use datafusion_common::cast::as_primitive_array;
 use futures::stream::{Fuse, StreamExt};
+use std::marker::PhantomData;
 use std::sync::Arc;
 use std::task::{ready, Context, Poll};
 
@@ -75,7 +78,7 @@ impl FusedStreams {
 /// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
 /// and computes [`RowCursor`] based on the provided [`PhysicalSortExpr`]
 #[derive(Debug)]
-pub(crate) struct RowCursorStream {
+pub struct RowCursorStream {
     /// Converter to convert output of physical expressions
     converter: RowConverter,
     /// The physical expressions to sort by
@@ -85,7 +88,7 @@ pub(crate) struct RowCursorStream {
 }
 
 impl RowCursorStream {
-    pub(crate) fn try_new(
+    pub fn try_new(
         schema: &Schema,
         expressions: &[PhysicalSortExpr],
         streams: Vec<SendableRecordBatchStream>,
@@ -139,3 +142,68 @@ impl PartitionedStream for RowCursorStream {
         }))
     }
 }
+
+/// Specialized stream for sorts on single primitive columns
+pub struct PrimitiveCursorStream<T: ArrowPrimitiveType> {
+    /// The physical expression to sort by
+    sort: PhysicalSortExpr,
+    /// Input streams
+    streams: FusedStreams,
+    phantom: PhantomData<fn(T) -> T>,
+}
+
+impl<T: ArrowPrimitiveType> std::fmt::Debug for PrimitiveCursorStream<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PrimitiveCursorStream")
+            .field("data_type", &T::DATA_TYPE)
+            .field("num_streams", &self.streams)
+            .finish()
+    }
+}
+
+impl<T: ArrowPrimitiveType> PrimitiveCursorStream<T> {
+    pub fn new(sort: PhysicalSortExpr, streams: Vec<SendableRecordBatchStream>) -> Self {
+        let streams = streams.into_iter().map(|s| s.fuse()).collect();
+        Self {
+            sort,
+            streams: FusedStreams(streams),
+            phantom: Default::default(),
+        }
+    }
+
+    fn convert_batch(
+        &mut self,
+        batch: &RecordBatch,
+    ) -> Result<PrimitiveCursor<T::Native>> {
+        let value = self.sort.expr.evaluate(batch)?;
+        let array = value.into_array(batch.num_rows());
+        let array = as_primitive_array::<T>(array.as_ref())?;
+
+        Ok(PrimitiveCursor::new(
+            self.sort.options,
+            array.values().clone(),
+            array.null_count(),
+        ))
+    }
+}
+
+impl<T: ArrowPrimitiveType> PartitionedStream for PrimitiveCursorStream<T> {
+    type Output = Result<(PrimitiveCursor<T::Native>, RecordBatch)>;
+
+    fn partitions(&self) -> usize {
+        self.streams.0.len()
+    }
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>> {
+        Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
+            r.and_then(|batch| {
+                let cursor = self.convert_batch(&batch)?;
+                Ok((cursor, batch))
+            })
+        }))
+    }
+}