You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/12/15 19:56:21 UTC
[arrow-rs] branch master updated: Mark `MutableBuffer::typed_data_mut` unsafe (#1029)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new f21fa54 Mark `MutableBuffer::typed_data_mut` unsafe (#1029)
f21fa54 is described below
commit f21fa54e026afd5636cb0308704332858c76fbc7
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Wed Dec 15 14:56:16 2021 -0500
Mark `MutableBuffer::typed_data_mut` unsafe (#1029)
* Mark `MutableBuffer::typed_data_mut` unsafe
* fmt
* Mark use of `typed_data_mut` as unsafe in simd kernels
---
arrow/src/buffer/mutable.rs | 21 +++++++++------
arrow/src/buffer/ops.rs | 5 +++-
arrow/src/compute/kernels/arithmetic.rs | 33 ++++++++++++++++-------
arrow/src/compute/kernels/comparison.rs | 47 +++++++++++++++++----------------
arrow/src/compute/kernels/sort.rs | 6 +++--
arrow/src/compute/kernels/take.rs | 3 ++-
parquet/src/arrow/array_reader.rs | 3 ++-
7 files changed, 73 insertions(+), 45 deletions(-)
diff --git a/arrow/src/buffer/mutable.rs b/arrow/src/buffer/mutable.rs
index 61593af..7beada9 100644
--- a/arrow/src/buffer/mutable.rs
+++ b/arrow/src/buffer/mutable.rs
@@ -273,15 +273,20 @@ impl MutableBuffer {
}
/// View this buffer as a slice of a specific type.
+ ///
/// # Safety
- /// This function must only be used when this buffer was extended with items of type `T`.
- /// Failure to do so results in undefined behavior.
- pub fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
- unsafe {
- let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
- assert!(prefix.is_empty() && suffix.is_empty());
- offsets
- }
+ ///
+ /// This function must only be used with buffers which are treated
+ /// as type `T` (e.g. extended with items of type `T`).
+ ///
+ /// # Panics
+ ///
+ /// This function panics if the underlying buffer is not aligned
+ /// correctly for type `T`.
+ pub unsafe fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
+ let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
+ assert!(prefix.is_empty() && suffix.is_empty());
+ offsets
}
/// Extends this buffer from a slice of items that can be represented in bytes, increasing its capacity if needed.
diff --git a/arrow/src/buffer/ops.rs b/arrow/src/buffer/ops.rs
index c37fd14..14d3811 100644
--- a/arrow/src/buffer/ops.rs
+++ b/arrow/src/buffer/ops.rs
@@ -168,7 +168,10 @@ where
MutableBuffer::new(ceil(len_in_bits, 8)).with_bitset(len_in_bits / 64 * 8, false);
let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits);
- let result_chunks = result.typed_data_mut::<u64>().iter_mut();
+
+ // Safety: buffer is always treated as type `u64` in the code
+ // below.
+ let result_chunks = unsafe { result.typed_data_mut::<u64>().iter_mut() };
result_chunks
.zip(left_chunks.iter())
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index f92888b..09d4b9f 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -57,7 +57,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);
result_chunks
@@ -111,7 +112,8 @@ where
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);
result_chunks
@@ -398,7 +400,8 @@ where
let buffer_size = left.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);
@@ -662,7 +665,10 @@ where
let valid_chunks = b.bit_chunks(0, left.len());
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
+
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks =
+ unsafe { result.typed_data_mut().chunks_exact_mut(64) };
let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);
@@ -707,7 +713,9 @@ where
)?;
}
None => {
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks =
+ unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);
@@ -784,7 +792,10 @@ where
let valid_chunks = b.bit_chunks(0, left.len());
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
+
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks =
+ unsafe { result.typed_data_mut().chunks_exact_mut(64) };
let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);
@@ -829,7 +840,9 @@ where
)?;
}
None => {
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks =
+ unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);
@@ -891,7 +904,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);
result_chunks
@@ -942,7 +956,8 @@ where
let buffer_size = array.len() * std::mem::size_of::<T::Native>();
let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
- let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
+ // safety: result is newly created above, always written as a T below
+ let mut result_chunks = unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
let mut array_chunks = array.values().chunks_exact(lanes);
result_chunks
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 125be91..33644c4 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -921,23 +921,22 @@ where
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);
+ // safety: result is newly created above, always written as a T below
+ let result_chunks = unsafe { result.typed_data_mut() };
let result_remainder = left_chunks
.borrow_mut()
.zip(right_chunks.borrow_mut())
- .fold(
- result.typed_data_mut(),
- |result_slice, (left_slice, right_slice)| {
- let simd_left = T::load(left_slice);
- let simd_right = T::load(right_slice);
- let simd_result = simd_op(simd_left, simd_right);
+ .fold(result_chunks, |result_slice, (left_slice, right_slice)| {
+ let simd_left = T::load(left_slice);
+ let simd_right = T::load(right_slice);
+ let simd_result = simd_op(simd_left, simd_right);
- let bitmask = T::mask_to_u64(&simd_result);
- let bytes = bitmask.to_le_bytes();
- result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+ let bitmask = T::mask_to_u64(&simd_result);
+ let bytes = bitmask.to_le_bytes();
+ result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
- &mut result_slice[lanes / 8..]
- },
- );
+ &mut result_slice[lanes / 8..]
+ });
let left_remainder = left_chunks.remainder();
let right_remainder = right_chunks.remainder();
@@ -1005,19 +1004,21 @@ where
let mut left_chunks = left.values().chunks_exact(lanes);
let simd_right = T::init(right);
- let result_remainder = left_chunks.borrow_mut().fold(
- result.typed_data_mut(),
- |result_slice, left_slice| {
- let simd_left = T::load(left_slice);
- let simd_result = simd_op(simd_left, simd_right);
+ // safety: result is newly created above, always written as a T below
+ let result_chunks = unsafe { result.typed_data_mut() };
+ let result_remainder =
+ left_chunks
+ .borrow_mut()
+ .fold(result_chunks, |result_slice, left_slice| {
+ let simd_left = T::load(left_slice);
+ let simd_result = simd_op(simd_left, simd_right);
- let bitmask = T::mask_to_u64(&simd_result);
- let bytes = bitmask.to_le_bytes();
- result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
+ let bitmask = T::mask_to_u64(&simd_result);
+ let bytes = bitmask.to_le_bytes();
+ result_slice[0..lanes / 8].copy_from_slice(&bytes[0..lanes / 8]);
- &mut result_slice[lanes / 8..]
- },
- );
+ &mut result_slice[lanes / 8..]
+ });
let left_remainder = left_chunks.remainder();
diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs
index 6a72224..1046853 100644
--- a/arrow/src/compute/kernels/sort.rs
+++ b/arrow/src/compute/kernels/sort.rs
@@ -471,7 +471,8 @@ fn sort_boolean(
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
+ // Safety: the buffer is always treated as `u32` in the code below
+ let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
if options.nulls_first {
let size = nulls_len.min(len);
@@ -559,7 +560,8 @@ where
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
- let result_slice: &mut [u32] = result.typed_data_mut();
+ // Safety: the buffer is always treated as `u32` in the code below
+ let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
if options.nulls_first {
let size = nulls_len.min(len);
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 9fe00ea..63df3ab 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -632,7 +632,8 @@ where
let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);
- let offsets = offsets_buffer.typed_data_mut();
+ // Safety: the buffer is always treated as a type of `OffsetSize` in the code below
+ let offsets = unsafe { offsets_buffer.typed_data_mut() };
let mut values = MutableBuffer::new(0);
let mut length_so_far = OffsetSize::zero();
offsets[0] = length_so_far;
diff --git a/parquet/src/arrow/array_reader.rs b/parquet/src/arrow/array_reader.rs
index c3abda0..ef8cf70 100644
--- a/parquet/src/arrow/array_reader.rs
+++ b/parquet/src/arrow/array_reader.rs
@@ -1140,7 +1140,8 @@ impl ArrayReader for StructArrayReader {
let mut def_level_data_buffer = MutableBuffer::new(buffer_size);
def_level_data_buffer.resize(buffer_size, 0);
- let def_level_data = def_level_data_buffer.typed_data_mut();
+ // Safety: the buffer is always treated as `u16` in the code below
+ let def_level_data = unsafe { def_level_data_buffer.typed_data_mut() };
def_level_data
.iter_mut()