You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "tustvold (via GitHub)" <gi...@apache.org> on 2023/07/03 09:44:14 UTC

[GitHub] [arrow-rs] tustvold commented on a diff in pull request #4473: Improve in-place primitive sorts by 13-67%

tustvold commented on code in PR #4473:
URL: https://github.com/apache/arrow-rs/pull/4473#discussion_r1250580989


##########
arrow-ord/src/sort.rs:
##########
@@ -57,11 +60,211 @@ pub fn sort(
     values: &dyn Array,
     options: Option<SortOptions>,
 ) -> Result<ArrayRef, ArrowError> {
-    if let DataType::RunEndEncoded(_, _) = values.data_type() {
-        return sort_run(values, options, None);
+    match values.data_type() {
+        DataType::Int8 => sort_native_type::<Int8Type, i8>(values, options),
+        DataType::Int16 => sort_native_type::<Int16Type, i16>(values, options),
+        DataType::Int32 => sort_native_type::<Int32Type, i32>(values, options),
+        DataType::Int64 => sort_native_type::<Int64Type, i64>(values, options),
+        DataType::UInt8 => sort_native_type::<UInt8Type, u8>(values, options),
+        DataType::UInt16 => sort_native_type::<UInt16Type, u16>(values, options),
+        DataType::UInt32 => sort_native_type::<UInt32Type, u32>(values, options),
+        DataType::UInt64 => sort_native_type::<UInt64Type, u64>(values, options),
+        DataType::Float32 => sort_native_type::<Float32Type, f32>(values, options),
+        DataType::Float64 => sort_native_type::<Float64Type, f64>(values, options),
+        DataType::Date32 => sort_native_type::<Date32Type, i32>(values, options),
+        DataType::Date64 => sort_native_type::<Date64Type, i64>(values, options),
+        DataType::Time32(TimeUnit::Second) => {
+            sort_native_type::<Time32SecondType, i32>(values, options)
+        }
+        DataType::Time32(TimeUnit::Millisecond) => {
+            sort_native_type::<Time32MillisecondType, i32>(values, options)
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            sort_native_type::<Time64MicrosecondType, i64>(values, options)
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            sort_native_type::<Time64NanosecondType, i64>(values, options)
+        }
+        DataType::Timestamp(TimeUnit::Second, _) => {
+            sort_native_type::<TimestampSecondType, i64>(values, options)
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, _) => {
+            sort_native_type::<TimestampMillisecondType, i64>(values, options)
+        }
+        DataType::Timestamp(TimeUnit::Microsecond, _) => {
+            sort_native_type::<TimestampMicrosecondType, i64>(values, options)
+        }
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => {
+            sort_native_type::<TimestampNanosecondType, i64>(values, options)
+        }
+        DataType::Interval(IntervalUnit::YearMonth) => {
+            sort_native_type::<IntervalYearMonthType, i32>(values, options)
+        }
+        DataType::Interval(IntervalUnit::DayTime) => {
+            sort_native_type::<IntervalDayTimeType, i64>(values, options)
+        }
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            sort_native_type::<IntervalMonthDayNanoType, i128>(values, options)
+        }
+        DataType::Duration(TimeUnit::Second) => {
+            sort_native_type::<DurationSecondType, i64>(values, options)
+        }
+        DataType::Duration(TimeUnit::Millisecond) => {
+            sort_native_type::<DurationMillisecondType, i64>(values, options)
+        }
+        DataType::Duration(TimeUnit::Microsecond) => {
+            sort_native_type::<DurationMicrosecondType, i64>(values, options)
+        }
+        DataType::Duration(TimeUnit::Nanosecond) => {
+            sort_native_type::<DurationNanosecondType, i64>(values, options)
+        }
+        DataType::RunEndEncoded(_, _) => sort_run(values, options, None),
+        _ => {
+            let indices = sort_to_indices(values, options, None)?;
+            take(values, &indices, None)
+        }
     }
-    let indices = sort_to_indices(values, options, None)?;
-    take(values, &indices, None)
+}
+
+fn compress_store<U>(input: *const U, mut output: *mut U, mask: u8) -> isize
+where
+    U: ArrowNativeType,
+{
+    let mut offset = 0;
+    if mask != 0 {
+        for i in 0..8 {
+            if (mask & (1 << i)) != 0 {
+                // This is safe since a valid bit i.e bit set to 1 indicates a valid value
+                unsafe {
+                    *output = *input.offset(i);
+                    offset += 1;
+                    output = output.offset(1);
+                }
+            }
+        }
+    }
+    offset
+}
+
+fn create_null_buffer(
+    valid_count: usize,
+    nulls_count: usize,
+    length: usize,
+    sort_options: SortOptions,
+) -> Option<Buffer> {
+    let null_capacity = (length / 8) + (length % 8 != 0) as usize;
+    let mut mutable_null_buffer = MutableBuffer::new(null_capacity * 8);
+    mutable_null_buffer.resize(null_capacity, 0);
+
+    let mutable_null_buffer_slice = mutable_null_buffer.as_slice_mut();
+
+    if valid_count > 0 {
+        let mut count = valid_count;
+        let mut index = 0;
+        if sort_options.nulls_first {
+            let remaining_nulls = nulls_count % 8;
+            index = nulls_count / 8;
+
+            if remaining_nulls != 0 {
+                let valid_values_count = min(8 - remaining_nulls, valid_count);
+                mutable_null_buffer_slice[index] =
+                    ((1 << valid_values_count) - 1) << remaining_nulls;
+                count -= valid_values_count;
+                index += 1;
+            }
+        }
+        while count >= 8 {
+            mutable_null_buffer_slice[index] = u8::MAX;
+            index += 1;
+            count -= 8;
+        }
+        if count != 0 {
+            mutable_null_buffer_slice[index] = (1 << count) - 1;
+        }
+    }
+
+    Some(mutable_null_buffer.into())
+}
+
+fn sort_native_type<T, U>(
+    values: &dyn Array,
+    options: Option<SortOptions>,
+) -> Result<ArrayRef, ArrowError>
+where
+    T: ArrowPrimitiveType,
+    U: ArrowNativeTypeOp,
+{
+    let sort_options = options.unwrap_or_default();
+    let values = values.as_primitive::<T>();
+
+    let result_capacity = values.len() * std::mem::size_of::<U>();
+    let mut mutable_buffer = MutableBuffer::new(result_capacity);
+    mutable_buffer.resize(result_capacity, 0);
+    let mutable_slice: &mut [U] = mutable_buffer.typed_data_mut();
+
+    let array_data = values.to_data();
+    let input_values: &[U] = array_data.buffer(0);
+
+    let mut null_bit_buffer = None;
+
+    let nulls_count = values.null_count();
+    let valid_count = values.len() - nulls_count;
+
+    if values.null_count() > 0 {
+        let nulls = array_data.nulls().unwrap();
+        let null_buffer = nulls.buffer().as_slice();
+
+        let mut mutable_slice_ptr = mutable_slice.as_mut_ptr();
+        let mut input_values_ptr = input_values.as_ptr();
+
+        if sort_options.nulls_first {
+            // This is safe since the offset in in bounds
+            unsafe {
+                mutable_slice_ptr = mutable_slice_ptr.add(values.null_count());
+            }
+        }
+
+        // This is safe since we are in bounds
+        let values_slice =
+            unsafe { slice::from_raw_parts_mut(mutable_slice_ptr, valid_count) };
+
+        for mask in null_buffer {
+            let written_count =
+                compress_store::<U>(input_values_ptr, mutable_slice_ptr, *mask);
+            // This is safe as the offset increments are within bounds
+            unsafe {
+                input_values_ptr = input_values_ptr.offset(8);
+                mutable_slice_ptr = mutable_slice_ptr.offset(written_count);
+            }
+        }
+
+        values_slice.sort_unstable_by(|a, b| a.compare(*b));
+        if sort_options.descending {
+            values_slice.reverse();
+        }
+
+        null_bit_buffer =
+            create_null_buffer(valid_count, nulls_count, values.len(), sort_options);
+    } else {
+        mutable_slice.copy_from_slice(input_values);
+        mutable_slice.sort_unstable_by(|a, b| a.compare(*b));
+        if sort_options.descending {
+            mutable_slice.reverse();
+        }
+    }
+    // This is safe since data types match

Review Comment:
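   Perhaps we could use `PrimitiveArray::<T>::new(buffer, nulls).with_data_type(values.data_type().clone())` instead.
   `PrimitiveArray::new` takes the values as a `ScalarBuffer<T::Native>`, so the compiler would check that the element
   type matches `T` rather than relying on the "data types match" safety comment above. `with_data_type` would then
   carry over type parameters that the match ignores, such as a timestamp's timezone.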



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org