You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/09/21 16:29:57 UTC

[arrow-rs] 01/01: chore: Reduce the amount of code generated by monomorphization (#715)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch cherry_pick_5c3ed612
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git

commit fd02abada7395872b230021fd808ff72e7cb9b49
Author: Markus Westerlind <ma...@distilnetworks.com>
AuthorDate: Mon Sep 13 18:55:56 2021 +0200

    chore: Reduce the amount of code generated by monomorphization (#715)
    
    * chore: Reduce the number of instantiations of take* (-3%)
    
    Many types have the same native type, so simplifying these functions to
    work directly with native types reduces the number of instantiations.
    
    Reduces the number of LLVM IR lines generated by ~3%
    
    * chore: Shrink try_from_trusted_len_iter (-0.5%)
    
    * chore: Make the inner take_ functions less generic (-3.5%)
    
    * chore: Extract the array sorter (-1%)
---
 arrow/src/buffer/mutable.rs       |  22 +++++---
 arrow/src/compute/kernels/take.rs | 107 +++++++++++++++++++++++++-------------
 2 files changed, 86 insertions(+), 43 deletions(-)

diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs
index 7d336e0..d83997a 100644
--- a/arrow/src/buffer/mutable.rs
+++ b/arrow/src/buffer/mutable.rs
@@ -530,12 +530,22 @@ impl MutableBuffer {
             std::ptr::write(dst, item?);
             dst = dst.add(1);
         }
-        assert_eq!(
-            dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
-            upper,
-            "Trusted iterator length was not accurately reported"
-        );
-        buffer.len = len;
+        // try_from_trusted_len_iter is instantiated a lot, so we extract part of it into a less
+        // generic method to reduce compile time
+        unsafe fn finalize_buffer<T>(
+            dst: *mut T,
+            buffer: &mut MutableBuffer,
+            upper: usize,
+            len: usize,
+        ) {
+            assert_eq!(
+                dst.offset_from(buffer.data.as_ptr() as *mut T) as usize,
+                upper,
+                "Trusted iterator length was not accurately reported"
+            );
+            buffer.len = len;
+        }
+        finalize_buffer(dst, &mut buffer, upper, len);
         Ok(buffer)
     }
 }
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 225f263..7147972 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -302,20 +302,17 @@ impl Default for TakeOptions {
 }
 
 #[inline(always)]
-fn maybe_usize<I: ArrowPrimitiveType>(index: I::Native) -> Result<usize> {
+fn maybe_usize<I: ArrowNativeType>(index: I) -> Result<usize> {
     index
         .to_usize()
         .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))
 }
 
 // take implementation when neither values nor indices contain nulls
-fn take_no_nulls<T, I>(
-    values: &[T::Native],
-    indices: &[I::Native],
-) -> Result<(Buffer, Option<Buffer>)>
+fn take_no_nulls<T, I>(values: &[T], indices: &[I]) -> Result<(Buffer, Option<Buffer>)>
 where
-    T: ArrowPrimitiveType,
-    I: ArrowNumericType,
+    T: ArrowNativeType,
+    I: ArrowNativeType,
 {
     let values = indices
         .iter()
@@ -329,27 +326,36 @@ where
 // take implementation when only values contain nulls
 fn take_values_nulls<T, I>(
     values: &PrimitiveArray<T>,
-    indices: &[I::Native],
+    indices: &[I],
 ) -> Result<(Buffer, Option<Buffer>)>
 where
     T: ArrowPrimitiveType,
-    I: ArrowNumericType,
-    I::Native: ToPrimitive,
+    I: ArrowNativeType,
+{
+    take_values_nulls_inner(values.data(), values.values(), indices)
+}
+
+fn take_values_nulls_inner<T, I>(
+    values_data: &ArrayData,
+    values: &[T],
+    indices: &[I],
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
 {
     let num_bytes = bit_util::ceil(indices.len(), 8);
     let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
     let null_slice = nulls.as_slice_mut();
     let mut null_count = 0;
 
-    let values_values = values.values();
-
     let values = indices.iter().enumerate().map(|(i, index)| {
         let index = maybe_usize::<I>(*index)?;
-        if values.is_null(index) {
+        if values_data.is_null(index) {
             null_count += 1;
             bit_util::unset_bit(null_slice, i);
         }
-        Result::Ok(values_values[index])
+        Result::Ok(values[index])
     });
     // Soundness: `slice.map` is `TrustedLen`.
     let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };
@@ -366,21 +372,33 @@ where
 
 // take implementation when only indices contain nulls
 fn take_indices_nulls<T, I>(
-    values: &[T::Native],
+    values: &[T],
     indices: &PrimitiveArray<I>,
 ) -> Result<(Buffer, Option<Buffer>)>
 where
-    T: ArrowPrimitiveType,
+    T: ArrowNativeType,
     I: ArrowNumericType,
     I::Native: ToPrimitive,
 {
-    let values = indices.values().iter().map(|index| {
+    take_indices_nulls_inner(values, indices.values(), indices.data())
+}
+
+fn take_indices_nulls_inner<T, I>(
+    values: &[T],
+    indices: &[I],
+    indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
+{
+    let values = indices.iter().map(|index| {
         let index = maybe_usize::<I>(*index)?;
         Result::Ok(match values.get(index) {
             Some(value) => *value,
             None => {
-                if indices.is_null(index) {
-                    T::Native::default()
+                if indices_data.is_null(index) {
+                    T::default()
                 } else {
                     panic!("Out-of-bounds index {}", index)
                 }
@@ -393,10 +411,9 @@ where
 
     Ok((
         buffer,
-        indices
-            .data_ref()
+        indices_data
             .null_buffer()
-            .map(|b| b.bit_slice(indices.offset(), indices.len())),
+            .map(|b| b.bit_slice(indices_data.offset(), indices.len())),
     ))
 }
 
@@ -410,25 +427,41 @@ where
     I: ArrowNumericType,
     I::Native: ToPrimitive,
 {
+    take_values_indices_nulls_inner(
+        values.values(),
+        values.data(),
+        indices.values(),
+        indices.data(),
+    )
+}
+
+fn take_values_indices_nulls_inner<T, I>(
+    values: &[T],
+    values_data: &ArrayData,
+    indices: &[I],
+    indices_data: &ArrayData,
+) -> Result<(Buffer, Option<Buffer>)>
+where
+    T: ArrowNativeType,
+    I: ArrowNativeType,
+{
     let num_bytes = bit_util::ceil(indices.len(), 8);
     let mut nulls = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
     let null_slice = nulls.as_slice_mut();
     let mut null_count = 0;
 
-    let values_values = values.values();
-    let values = indices.iter().enumerate().map(|(i, index)| match index {
-        Some(index) => {
+    let values = indices.iter().enumerate().map(|(i, &index)| {
+        if indices_data.is_null(i) {
+            null_count += 1;
+            bit_util::unset_bit(null_slice, i);
+            Ok(T::default())
+        } else {
             let index = maybe_usize::<I>(index)?;
-            if values.is_null(index) {
+            if values_data.is_null(index) {
                 null_count += 1;
                 bit_util::unset_bit(null_slice, i);
             }
-            Result::Ok(values_values[index])
-        }
-        None => {
-            null_count += 1;
-            bit_util::unset_bit(null_slice, i);
-            Ok(T::Native::default())
+            Result::Ok(values[index])
         }
     });
     // Soundness: `slice.map` is `TrustedLen`.
@@ -471,17 +504,17 @@ where
         (false, false) => {
             // * no nulls
             // * all `indices.values()` are valid
-            take_no_nulls::<T, I>(values.values(), indices.values())?
+            take_no_nulls::<T::Native, I::Native>(values.values(), indices.values())?
         }
         (true, false) => {
             // * nulls come from `values` alone
             // * all `indices.values()` are valid
-            take_values_nulls::<T, I>(values, indices.values())?
+            take_values_nulls::<T, I::Native>(values, indices.values())?
         }
         (false, true) => {
             // in this branch it is unsound to read and use `index.values()`,
             // as doing so is UB when they come from a null slot.
-            take_indices_nulls::<T, I>(values.values(), indices)?
+            take_indices_nulls::<T::Native, I>(values.values(), indices)?
         }
         (true, true) => {
             // in this branch it is unsound to read and use `index.values()`,
@@ -795,7 +828,7 @@ where
         .values()
         .iter()
         .map(|idx| {
-            let idx = maybe_usize::<IndexType>(*idx)?;
+            let idx = maybe_usize::<IndexType::Native>(*idx)?;
             if data_ref.is_valid(idx) {
                 Ok(Some(values.value(idx)))
             } else {
@@ -821,7 +854,7 @@ where
         .values()
         .iter()
         .map(|idx| {
-            let idx = maybe_usize::<IndexType>(*idx)?;
+            let idx = maybe_usize::<IndexType::Native>(*idx)?;
             if data_ref.is_valid(idx) {
                 Ok(Some(values.value(idx)))
             } else {