You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/15 19:56:21 UTC

[arrow-rs] branch master updated: Mark `MutableBuffer::typed_data_mut` unsafe (#1029)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new f21fa54  Mark `MutableBuffer::typed_data_mut` unsafe (#1029)
f21fa54 is described below

commit f21fa54e026afd5636cb0308704332858c76fbc7
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Dec 15 14:56:16 2021 -0500

    Mark `MutableBuffer::typed_data_mut` unsafe (#1029)
    
    * Mark `MutableBuffer::typed_data_mut` unsafe
    
    * fmt
    
    * Mark use of `typed_data_mut` as unsafe in simd kernels
---
 arrow/src/buffer/mutable.rs             | 21 +++++++++------
 arrow/src/buffer/ops.rs                 |  5 +++-
 arrow/src/compute/kernels/arithmetic.rs | 33 ++++++++++++++++-------
 arrow/src/compute/kernels/comparison.rs | 47 +++++++++++++++++----------------
 arrow/src/compute/kernels/sort.rs       |  6 +++--
 arrow/src/compute/kernels/take.rs       |  3 ++-
 parquet/src/arrow/array_reader.rs       |  3 ++-
 7 files changed, 73 insertions(+), 45 deletions(-)

diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs
index 61593af..7beada9 100644
--- a/arrow/src/buffer/mutable.rs
+++ b/arrow/src/buffer/mutable.rs
@@ -273,15 +273,20 @@ impl MutableBuffer {
     }
 
     /// View this buffer asa slice of a specific type.
+    ///
     /// # Safety
-    /// This function must only be used when this buffer was extended with items of type `T`.
-    /// Failure to do so results in undefined behavior.
-    pub fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
-        unsafe {
-            let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
-            assert!(prefix.is_empty() && suffix.is_empty());
-            offsets
-        }
+    ///
+    /// This function must only be used with buffers which are treated
+    /// as type `T` (e.g. extended with items of type `T`).
+    ///
+    /// # Panics
+    ///
+    /// This function panics if the underlying buffer is not aligned
+    /// correctly for type `T`.
+    pub unsafe fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
+        let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
+        assert!(prefix.is_empty() && suffix.is_empty());
+        offsets
     }
 
     /// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed.
diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs
index c37fd14..14d3811 100644
--- a/arrow/src/buffer/ops.rs
+++ b/arrow/src/buffer/ops.rs
@@ -168,7 +168,10 @@ where
         MutableBuffer::new(ceil(len_in_bits, 8)).with_bitset(len_in_bits / 64 * 8, false);
 
     let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits);
-    let result_chunks = result.typed_data_mut::<u64>().iter_mut();
+
+    // Safety: buffer is always treated as type `u64` in the code
+    // below.
+    let result_chunks = unsafe { result.typed_data_mut::<u64>().iter_mut() };
 
     result_chunks
         .zip(left_chunks.iter())
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index f92888b..09d4b9f 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -57,7 +57,8 @@ where
     let buffer_size = array.len() * std::mem::size_of::<T::Native>();
     let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+    // safety: result is newly created above, always written as a T below
+    let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
     let mut array_chunks = array.values().chunks_exact(lanes);
 
     result_chunks
@@ -111,7 +112,8 @@ where
 
     let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+    // safety: result is newly created above, always written as a T below
+    let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
     let mut array_chunks = array.values().chunks_exact(lanes);
 
     result_chunks
@@ -398,7 +400,8 @@ where
     let buffer_size = left.len() * std::mem::size_of::<T::Native>();
     let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+    // safety: result is newly created above, always written as a T below
+    let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
     let mut left_chunks = left.values().chunks_exact(lanes);
     let mut right_chunks = right.values().chunks_exact(lanes);
 
@@ -662,7 +665,10 @@ where
             let valid_chunks = b.bit_chunks(0, left.len());
 
             // process data in chunks of 64 elements since we also get 64 bits of validity information at a time
-            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
+
+            // safety: result is newly created above, always written as a T below
+            let mut result_chunks =
+                unsafe { result.typed_data_mut().chunks_exact_mut(64) };
             let mut left_chunks = left.values().chunks_exact(64);
             let mut right_chunks = right.values().chunks_exact(64);
 
@@ -707,7 +713,9 @@ where
             )?;
         }
         None => {
-            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+            // safety: result is newly created above, always written as a T below
+            let mut result_chunks =
+                unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
             let mut left_chunks = left.values().chunks_exact(lanes);
             let mut right_chunks = right.values().chunks_exact(lanes);
 
@@ -784,7 +792,10 @@ where
             let valid_chunks = b.bit_chunks(0, left.len());
 
             // process data in chunks of 64 elements since we also get 64 bits of validity information at a time
-            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
+
+            // safety: result is newly created above, always written as a T below
+            let mut result_chunks =
+                unsafe { result.typed_data_mut().chunks_exact_mut(64) };
             let mut left_chunks = left.values().chunks_exact(64);
             let mut right_chunks = right.values().chunks_exact(64);
 
@@ -829,7 +840,9 @@ where
             )?;
         }
         None => {
-            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+            // safety: result is newly created above, always written as a T below
+            let mut result_chunks =
+                unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
             let mut left_chunks = left.values().chunks_exact(lanes);
             let mut right_chunks = right.values().chunks_exact(lanes);
 
@@ -891,7 +904,8 @@ where
     let buffer_size = array.len() * std::mem::size_of::<T::Native>();
     let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+    // safety: result is newly created above, always written as a T below
+    let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
     let mut array_chunks = array.values().chunks_exact(lanes);
 
     result_chunks
@@ -942,7 +956,8 @@ where
     let buffer_size = array.len() * std::mem::size_of::<T::Native>();
     let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
 
-    let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+    // safety: result is newly created above, always written as a T below
+    let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
     let mut array_chunks = array.values().chunks_exact(lanes);
 
     result_chunks
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 125be91..33644c4 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -921,23 +921,22 @@ where
     let mut left_chunks = left.values().chunks_exact(lanes);
     let mut right_chunks = right.values().chunks_exact(lanes);
 
+    // safety: result is newly created above, always written as a T below
+    let result_chunks = unsafe { result.typed_data_mut() };
     let result_remainder = left_chunks
         .borrow_mut()
         .zip(right_chunks.borrow_mut())
-        .fold(
-            result.typed_data_mut(),
-            |result_slice, (left_slice, right_slice)| {
-                let simd_left = T::load(left_slice);
-                let simd_right = T::load(right_slice);
-                let simd_result = simd_op(simd_left, simd_right);
+        .fold(result_chunks, |result_slice, (left_slice, right_slice)| {
+            let simd_left = T::load(left_slice);
+            let simd_right = T::load(right_slice);
+            let simd_result = simd_op(simd_left, simd_right);
 
-                let bitmask = T::mask_to_u64(&simd_result);
-                let bytes = bitmask.to_le_bytes();
-                result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+            let bitmask = T::mask_to_u64(&simd_result);
+            let bytes = bitmask.to_le_bytes();
+            result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
 
-                &mut result_slice[lanes / 8..]
-            },
-        );
+            &mut result_slice[lanes / 8..]
+        });
 
     let left_remainder = left_chunks.remainder();
     let right_remainder = right_chunks.remainder();
@@ -1005,19 +1004,21 @@ where
     let mut left_chunks = left.values().chunks_exact(lanes);
     let simd_right = T::init(right);
 
-    let result_remainder = left_chunks.borrow_mut().fold(
-        result.typed_data_mut(),
-        |result_slice, left_slice| {
-            let simd_left = T::load(left_slice);
-            let simd_result = simd_op(simd_left, simd_right);
+    // safety: result is newly created above, always written as a T below
+    let result_chunks = unsafe { result.typed_data_mut() };
+    let result_remainder =
+        left_chunks
+            .borrow_mut()
+            .fold(result_chunks, |result_slice, left_slice| {
+                let simd_left = T::load(left_slice);
+                let simd_result = simd_op(simd_left, simd_right);
 
-            let bitmask = T::mask_to_u64(&simd_result);
-            let bytes = bitmask.to_le_bytes();
-            result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+                let bitmask = T::mask_to_u64(&simd_result);
+                let bytes = bitmask.to_le_bytes();
+                result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
 
-            &mut result_slice[lanes / 8..]
-        },
-    );
+                &mut result_slice[lanes / 8..]
+            });
 
     let left_remainder = left_chunks.remainder();
 
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 6a72224..1046853 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -471,7 +471,8 @@ fn sort_boolean(
     let mut result = MutableBuffer::new(result_capacity);
     // sets len to capacity so we can access the whole buffer as a typed slice
     result.resize(result_capacity, 0);
-    let result_slice: &mut [u32] = result.typed_data_mut();
+    // Safety: the buffer is always treated as `u32` in the code below
+    let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
 
     if options.nulls_first {
         let size = nulls_len.min(len);
@@ -559,7 +560,8 @@ where
     let mut result = MutableBuffer::new(result_capacity);
     // sets len to capacity so we can access the whole buffer as a typed slice
     result.resize(result_capacity, 0);
-    let result_slice: &mut [u32] = result.typed_data_mut();
+    // Safety: the buffer is always treated as `u32` in the code below
+    let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
 
     if options.nulls_first {
         let size = nulls_len.min(len);
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 9fe00ea..63df3ab 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -632,7 +632,8 @@ where
     let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
     let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);
 
-    let offsets = offsets_buffer.typed_data_mut();
+    // Safety: the buffer is always treated as a type of `OffsetSize` in the code below
+    let offsets = unsafe { offsets_buffer.typed_data_mut() };
     let mut values = MutableBuffer::new(0);
     let mut length_so_far = OffsetSize::zero();
     offsets[0] = length_so_far;
diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs
index c3abda0..ef8cf70 100644
--- a/parquet/src/arrow/array_reader.rs
+++ b/parquet/src/arrow/array_reader.rs
@@ -1140,7 +1140,8 @@ impl ArrayReader for StructArrayReader {
         let mut def_level_data_buffer = MutableBuffer::new(buffer_size);
         def_level_data_buffer.resize(buffer_size, 0);
 
-        let def_level_data = def_level_data_buffer.typed_data_mut();
+        // Safety: the buffer is always treated as `u16` in the code below
+        let def_level_data = unsafe { def_level_data_buffer.typed_data_mut() };
 
         def_level_data
             .iter_mut()