You are viewing a plain-text version of this content. The canonical link to it ("here" in the original page) was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/04/12 10:14:46 UTC

[arrow-datafusion] branch main updated: Specialized Cursor for StringArray and BinaryArray (#5964)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e17572cd19 Specialized Cursor for StringArray and BinaryArray (#5964)
e17572cd19 is described below

commit e17572cd19b30578614d3e2a7b6a08019bd23baf
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Apr 12 11:14:39 2023 +0100

    Specialized Cursor for StringArray and BinaryArray (#5964)
    
    * Generify
    
    * Specialized cursor for StringArray and BinaryArray
    
    * fix clippy
    
    * Review feedback
    
    ---------
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/src/physical_plan/sorts/cursor.rs | 156 +++++++++++++++++-----
 datafusion/core/src/physical_plan/sorts/merge.rs  |  18 ++-
 datafusion/core/src/physical_plan/sorts/stream.rs |  30 ++---
 3 files changed, 148 insertions(+), 56 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs
index 7e8d600542..a9e5122130 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -19,6 +19,8 @@ use crate::physical_plan::sorts::sort::SortOptions;
 use arrow::buffer::ScalarBuffer;
 use arrow::datatypes::ArrowNativeTypeOp;
 use arrow::row::{Row, Rows};
+use arrow_array::types::ByteArrayType;
+use arrow_array::{Array, ArrowPrimitiveType, GenericByteArray, PrimitiveArray};
 use std::cmp::Ordering;
 
 /// A [`Cursor`] for [`Rows`]
@@ -97,12 +99,90 @@ impl Cursor for RowCursor {
     }
 }
 
-/// A cursor over sorted, nullable [`ArrowNativeTypeOp`]
+/// An [`Array`] that can be converted into [`FieldValues`]
+pub trait FieldArray: Array + 'static {
+    type Values: FieldValues;
+
+    fn values(&self) -> Self::Values;
+}
+
+/// A comparable set of non-nullable values
+pub trait FieldValues {
+    type Value: ?Sized;
+
+    fn len(&self) -> usize;
+
+    fn compare(a: &Self::Value, b: &Self::Value) -> Ordering;
+
+    fn value(&self, idx: usize) -> &Self::Value;
+}
+
+impl<T: ArrowPrimitiveType> FieldArray for PrimitiveArray<T> {
+    type Values = PrimitiveValues<T::Native>;
+
+    fn values(&self) -> Self::Values {
+        PrimitiveValues(self.values().clone())
+    }
+}
+
+#[derive(Debug)]
+pub struct PrimitiveValues<T: ArrowNativeTypeOp>(ScalarBuffer<T>);
+
+impl<T: ArrowNativeTypeOp> FieldValues for PrimitiveValues<T> {
+    type Value = T;
+
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+
+    #[inline]
+    fn compare(a: &Self::Value, b: &Self::Value) -> Ordering {
+        T::compare(*a, *b)
+    }
+
+    #[inline]
+    fn value(&self, idx: usize) -> &Self::Value {
+        &self.0[idx]
+    }
+}
+
+impl<T: ByteArrayType> FieldArray for GenericByteArray<T> {
+    type Values = Self;
+
+    fn values(&self) -> Self::Values {
+        // Once https://github.com/apache/arrow-rs/pull/4048 is released
+        // Could potentially destructure array into buffers to reduce codegen,
+        // in a similar vein to what is done for PrimitiveArray
+        self.clone()
+    }
+}
+
+impl<T: ByteArrayType> FieldValues for GenericByteArray<T> {
+    type Value = T::Native;
+
+    fn len(&self) -> usize {
+        Array::len(self)
+    }
+
+    #[inline]
+    fn compare(a: &Self::Value, b: &Self::Value) -> Ordering {
+        let a: &[u8] = a.as_ref();
+        let b: &[u8] = b.as_ref();
+        a.cmp(b)
+    }
+
+    #[inline]
+    fn value(&self, idx: usize) -> &Self::Value {
+        self.value(idx)
+    }
+}
+
+/// A cursor over sorted, nullable [`FieldValues`]
 ///
 /// Note: comparing cursors with different `SortOptions` will yield an arbitrary ordering
 #[derive(Debug)]
-pub struct PrimitiveCursor<T: ArrowNativeTypeOp> {
-    values: ScalarBuffer<T>,
+pub struct FieldCursor<T: FieldValues> {
+    values: T,
     offset: usize,
     // If nulls first, the first non-null index
     // Otherwise, the first null index
@@ -110,18 +190,16 @@ pub struct PrimitiveCursor<T: ArrowNativeTypeOp> {
     options: SortOptions,
 }
 
-impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
-    /// Create a new [`PrimitiveCursor`] from the provided `values` sorted according to `options`
-    pub fn new(options: SortOptions, values: ScalarBuffer<T>, null_count: usize) -> Self {
-        assert!(null_count <= values.len());
-
+impl<T: FieldValues> FieldCursor<T> {
+    /// Create a new [`FieldCursor`] from the provided `values` sorted according to `options`
+    pub fn new<A: FieldArray<Values = T>>(options: SortOptions, array: &A) -> Self {
         let null_threshold = match options.nulls_first {
-            true => null_count,
-            false => values.len() - null_count,
+            true => array.null_count(),
+            false => array.len() - array.null_count(),
         };
 
         Self {
-            values,
+            values: array.values(),
             offset: 0,
             null_threshold,
             options,
@@ -131,26 +209,22 @@ impl<T: ArrowNativeTypeOp> PrimitiveCursor<T> {
     fn is_null(&self) -> bool {
         (self.offset < self.null_threshold) == self.options.nulls_first
     }
-
-    fn value(&self) -> T {
-        self.values[self.offset]
-    }
 }
 
-impl<T: ArrowNativeTypeOp> PartialEq for PrimitiveCursor<T> {
+impl<T: FieldValues> PartialEq for FieldCursor<T> {
     fn eq(&self, other: &Self) -> bool {
         self.cmp(other).is_eq()
     }
 }
 
-impl<T: ArrowNativeTypeOp> Eq for PrimitiveCursor<T> {}
-impl<T: ArrowNativeTypeOp> PartialOrd for PrimitiveCursor<T> {
+impl<T: FieldValues> Eq for FieldCursor<T> {}
+impl<T: FieldValues> PartialOrd for FieldCursor<T> {
     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
         Some(self.cmp(other))
     }
 }
 
-impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
+impl<T: FieldValues> Ord for FieldCursor<T> {
     fn cmp(&self, other: &Self) -> Ordering {
         match (self.is_null(), other.is_null()) {
             (true, true) => Ordering::Equal,
@@ -163,19 +237,19 @@ impl<T: ArrowNativeTypeOp> Ord for PrimitiveCursor<T> {
                 false => Ordering::Less,
             },
             (false, false) => {
-                let s_v = self.value();
-                let o_v = other.value();
+                let s_v = self.values.value(self.offset);
+                let o_v = other.values.value(other.offset);
 
                 match self.options.descending {
-                    true => o_v.compare(s_v),
-                    false => s_v.compare(o_v),
+                    true => T::compare(o_v, s_v),
+                    false => T::compare(s_v, o_v),
                 }
             }
         }
     }
 }
 
-impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
+impl<T: FieldValues> Cursor for FieldCursor<T> {
     fn is_finished(&self) -> bool {
         self.offset == self.values.len()
     }
@@ -191,6 +265,24 @@ impl<T: ArrowNativeTypeOp> Cursor for PrimitiveCursor<T> {
 mod tests {
     use super::*;
 
+    fn new_primitive(
+        options: SortOptions,
+        values: ScalarBuffer<i32>,
+        null_count: usize,
+    ) -> FieldCursor<PrimitiveValues<i32>> {
+        let null_threshold = match options.nulls_first {
+            true => null_count,
+            false => values.len() - null_count,
+        };
+
+        FieldCursor {
+            offset: 0,
+            values: PrimitiveValues(values),
+            null_threshold,
+            options,
+        }
+    }
+
     #[test]
     fn test_primitive_nulls_first() {
         let options = SortOptions {
@@ -199,9 +291,9 @@ mod tests {
         };
 
         let buffer = ScalarBuffer::from(vec![i32::MAX, 1, 2, 3]);
-        let mut a = PrimitiveCursor::new(options, buffer, 1);
+        let mut a = new_primitive(options, buffer, 1);
         let buffer = ScalarBuffer::from(vec![1, 2, -2, -1, 1, 9]);
-        let mut b = PrimitiveCursor::new(options, buffer, 2);
+        let mut b = new_primitive(options, buffer, 2);
 
         // NULL == NULL
         assert_eq!(a.cmp(&b), Ordering::Equal);
@@ -243,9 +335,9 @@ mod tests {
         };
 
         let buffer = ScalarBuffer::from(vec![0, 1, i32::MIN, i32::MAX]);
-        let mut a = PrimitiveCursor::new(options, buffer, 2);
+        let mut a = new_primitive(options, buffer, 2);
         let buffer = ScalarBuffer::from(vec![-1, i32::MAX, i32::MIN]);
-        let mut b = PrimitiveCursor::new(options, buffer, 2);
+        let mut b = new_primitive(options, buffer, 2);
 
         // 0 > -1
         assert_eq!(a.cmp(&b), Ordering::Greater);
@@ -269,9 +361,9 @@ mod tests {
         };
 
         let buffer = ScalarBuffer::from(vec![6, 1, i32::MIN, i32::MAX]);
-        let mut a = PrimitiveCursor::new(options, buffer, 3);
+        let mut a = new_primitive(options, buffer, 3);
         let buffer = ScalarBuffer::from(vec![67, -3, i32::MAX, i32::MIN]);
-        let mut b = PrimitiveCursor::new(options, buffer, 2);
+        let mut b = new_primitive(options, buffer, 2);
 
         // 6 > 67
         assert_eq!(a.cmp(&b), Ordering::Greater);
@@ -299,9 +391,9 @@ mod tests {
         };
 
         let buffer = ScalarBuffer::from(vec![i32::MIN, i32::MAX, 6, 3]);
-        let mut a = PrimitiveCursor::new(options, buffer, 2);
+        let mut a = new_primitive(options, buffer, 2);
         let buffer = ScalarBuffer::from(vec![i32::MAX, 4546, -3]);
-        let mut b = PrimitiveCursor::new(options, buffer, 1);
+        let mut b = new_primitive(options, buffer, 1);
 
         // NULL == NULL
         assert_eq!(a.cmp(&b), Ordering::Equal);
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
index 1ea89b9a81..7e2d986e9d 100644
--- a/datafusion/core/src/physical_plan/sorts/merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -20,21 +20,27 @@ use crate::physical_plan::metrics::MemTrackingMetrics;
 use crate::physical_plan::sorts::builder::BatchBuilder;
 use crate::physical_plan::sorts::cursor::Cursor;
 use crate::physical_plan::sorts::stream::{
-    PartitionedStream, PrimitiveCursorStream, RowCursorStream,
+    FieldCursorStream, PartitionedStream, RowCursorStream,
 };
 use crate::physical_plan::{
     PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
 };
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{DataType, SchemaRef};
 use arrow::record_batch::RecordBatch;
-use arrow_array::downcast_primitive;
+use arrow_array::*;
 use futures::Stream;
 use std::pin::Pin;
 use std::task::{ready, Context, Poll};
 
 macro_rules! primitive_merge_helper {
+    ($t:ty, $($v:ident),+) => {
+        merge_helper!(PrimitiveArray<$t>, $($v),+)
+    };
+}
+
+macro_rules! merge_helper {
     ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident) => {{
-        let streams = PrimitiveCursorStream::<$t>::new($sort, $streams);
+        let streams = FieldCursorStream::<$t>::new($sort, $streams);
         return Ok(Box::pin(SortPreservingMergeStream::new(
             Box::new(streams),
             $schema,
@@ -58,6 +64,10 @@ pub(crate) fn streaming_merge(
         let data_type = sort.expr.data_type(schema.as_ref())?;
         downcast_primitive! {
             data_type => (primitive_merge_helper, sort, streams, schema, tracking_metrics, batch_size),
+            DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, tracking_metrics, batch_size)
+            DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, tracking_metrics, batch_size)
+            DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, tracking_metrics, batch_size)
+            DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, tracking_metrics, batch_size)
             _ => {}
         }
     }
diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs b/datafusion/core/src/physical_plan/sorts/stream.rs
index d6f49e58b5..9de6e260db 100644
--- a/datafusion/core/src/physical_plan/sorts/stream.rs
+++ b/datafusion/core/src/physical_plan/sorts/stream.rs
@@ -16,14 +16,13 @@
 // under the License.
 
 use crate::common::Result;
-use crate::physical_plan::sorts::cursor::{PrimitiveCursor, RowCursor};
+use crate::physical_plan::sorts::cursor::{FieldArray, FieldCursor, RowCursor};
 use crate::physical_plan::SendableRecordBatchStream;
 use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
-use arrow::array::{Array, ArrowPrimitiveType};
+use arrow::array::Array;
 use arrow::datatypes::Schema;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, SortField};
-use datafusion_common::cast::as_primitive_array;
 use futures::stream::{Fuse, StreamExt};
 use std::marker::PhantomData;
 use std::sync::Arc;
@@ -144,7 +143,7 @@ impl PartitionedStream for RowCursorStream {
 }
 
 /// Specialized stream for sorts on single primitive columns
-pub struct PrimitiveCursorStream<T: ArrowPrimitiveType> {
+pub struct FieldCursorStream<T: FieldArray> {
     /// The physical expressions to sort by
     sort: PhysicalSortExpr,
     /// Input streams
@@ -152,16 +151,15 @@ pub struct PrimitiveCursorStream<T: ArrowPrimitiveType> {
     phantom: PhantomData<fn(T) -> T>,
 }
 
-impl<T: ArrowPrimitiveType> std::fmt::Debug for PrimitiveCursorStream<T> {
+impl<T: FieldArray> std::fmt::Debug for FieldCursorStream<T> {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("PrimitiveCursorStream")
-            .field("data_type", &T::DATA_TYPE)
             .field("num_streams", &self.streams)
             .finish()
     }
 }
 
-impl<T: ArrowPrimitiveType> PrimitiveCursorStream<T> {
+impl<T: FieldArray> FieldCursorStream<T> {
     pub fn new(sort: PhysicalSortExpr, streams: Vec<SendableRecordBatchStream>) -> Self {
         let streams = streams.into_iter().map(|s| s.fuse()).collect();
         Self {
@@ -171,24 +169,16 @@ impl<T: ArrowPrimitiveType> PrimitiveCursorStream<T> {
         }
     }
 
-    fn convert_batch(
-        &mut self,
-        batch: &RecordBatch,
-    ) -> Result<PrimitiveCursor<T::Native>> {
+    fn convert_batch(&mut self, batch: &RecordBatch) -> Result<FieldCursor<T::Values>> {
         let value = self.sort.expr.evaluate(batch)?;
         let array = value.into_array(batch.num_rows());
-        let array = as_primitive_array::<T>(array.as_ref())?;
-
-        Ok(PrimitiveCursor::new(
-            self.sort.options,
-            array.values().clone(),
-            array.null_count(),
-        ))
+        let array = array.as_any().downcast_ref::<T>().expect("field values");
+        Ok(FieldCursor::new(self.sort.options, array))
     }
 }
 
-impl<T: ArrowPrimitiveType> PartitionedStream for PrimitiveCursorStream<T> {
-    type Output = Result<(PrimitiveCursor<T::Native>, RecordBatch)>;
+impl<T: FieldArray> PartitionedStream for FieldCursorStream<T> {
+    type Output = Result<(FieldCursor<T::Values>, RecordBatch)>;
 
     fn partitions(&self) -> usize {
         self.streams.0.len()