You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/10/03 14:59:48 UTC
[arrow-rs] branch master updated: Implement ArrowNumericType for Float16Type (#2810)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 9c1748f9c Implement ArrowNumericType for Float16Type (#2810)
9c1748f9c is described below
commit 9c1748f9cb6a125e18e64bd5da17cc1782a4b2a5
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Oct 3 15:59:42 2022 +0100
Implement ArrowNumericType for Float16Type (#2810)
* Implement ArrowNumericType for Float16Type
* Remove unnecessary safety comments
---
arrow/src/compute/kernels/arithmetic.rs | 31 ++++++++---
arrow/src/compute/kernels/comparison.rs | 7 +--
arrow/src/datatypes/numeric.rs | 96 +++++++++++++++++++++++++++++++++
3 files changed, 123 insertions(+), 11 deletions(-)
diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 1c28c9895..b2e95ad5e 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -332,9 +332,7 @@ where
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
- // safety: result is newly created above, always written as a T below
- let mut result_chunks =
- unsafe { result.typed_data_mut().chunks_exact_mut(64) };
+ let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
let mut left_chunks = left.values().chunks_exact(64);
let mut right_chunks = right.values().chunks_exact(64);
@@ -380,9 +378,7 @@ where
)?;
}
None => {
- // safety: result is newly created above, always written as a T below
- let mut result_chunks =
- unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
+ let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
let mut left_chunks = left.values().chunks_exact(lanes);
let mut right_chunks = right.values().chunks_exact(lanes);
@@ -1611,6 +1607,7 @@ mod tests {
use crate::array::Int32Array;
use crate::datatypes::{Date64Type, Int32Type, Int8Type};
use chrono::NaiveDate;
+ use half::f16;
#[test]
fn test_primitive_array_add() {
@@ -2898,4 +2895,26 @@ mod tests {
let division_by_zero = divide_scalar_opt_dyn::<Int32Type>(&a, 0);
assert_eq!(&expected, &division_by_zero.unwrap());
}
+
+ #[test]
+ fn test_sum_f16() {
+ let a = Float16Array::from_iter_values([
+ f16::from_f32(0.1),
+ f16::from_f32(0.2),
+ f16::from_f32(1.5),
+ f16::from_f32(-0.1),
+ ]);
+ let b = Float16Array::from_iter_values([
+ f16::from_f32(5.1),
+ f16::from_f32(6.2),
+ f16::from_f32(-1.),
+ f16::from_f32(-2.1),
+ ]);
+ let expected = Float16Array::from_iter_values(
+ a.values().iter().zip(b.values()).map(|(a, b)| a + b),
+ );
+
+ let c = add(&a, &b).unwrap();
+ assert_eq!(c, expected);
+ }
}
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 49aecfb67..1ea433150 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -1792,7 +1792,6 @@ where
.iter()
.map(|key| {
key.map(|key| unsafe {
- // safety lengths were verified above
let key = key.as_usize();
dict_comparison.value_unchecked(key)
})
@@ -1845,8 +1844,7 @@ where
let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE);
- // safety: result is newly created above, always written as a T below
- let result_chunks = unsafe { result.typed_data_mut() };
+ let result_chunks = result.typed_data_mut();
let result_remainder = left_chunks
.borrow_mut()
.zip(right_chunks.borrow_mut())
@@ -1937,8 +1935,7 @@ where
let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
let simd_right = T::init(right);
- // safety: result is newly created above, always written as a T below
- let result_chunks = unsafe { result.typed_data_mut() };
+ let result_chunks = result.typed_data_mut();
let result_remainder =
left_chunks
.borrow_mut()
diff --git a/arrow/src/datatypes/numeric.rs b/arrow/src/datatypes/numeric.rs
index b8fa87197..e74764d4c 100644
--- a/arrow/src/datatypes/numeric.rs
+++ b/arrow/src/datatypes/numeric.rs
@@ -366,6 +366,102 @@ make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8);
+#[cfg(not(feature = "simd"))]
+impl ArrowNumericType for Float16Type {}
+
+#[cfg(feature = "simd")]
+impl ArrowNumericType for Float16Type {
+ type Simd = <Float32Type as ArrowNumericType>::Simd;
+ type SimdMask = <Float32Type as ArrowNumericType>::SimdMask;
+
+ fn lanes() -> usize {
+ Float32Type::lanes()
+ }
+
+ fn init(value: Self::Native) -> Self::Simd {
+ Float32Type::init(value.to_f32())
+ }
+
+ fn load(slice: &[Self::Native]) -> Self::Simd {
+ let mut s = [0_f32; Self::Simd::lanes()];
+ s.iter_mut().zip(slice).for_each(|(o, a)| *o = a.to_f32());
+ Float32Type::load(&s)
+ }
+
+ fn mask_init(value: bool) -> Self::SimdMask {
+ Float32Type::mask_init(value)
+ }
+
+ fn mask_from_u64(mask: u64) -> Self::SimdMask {
+ Float32Type::mask_from_u64(mask)
+ }
+
+ fn mask_to_u64(mask: &Self::SimdMask) -> u64 {
+ Float32Type::mask_to_u64(mask)
+ }
+
+ fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
+ Float32Type::mask_get(mask, idx)
+ }
+
+ fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask {
+ Float32Type::mask_set(mask, idx, value)
+ }
+
+ fn mask_select(mask: Self::SimdMask, a: Self::Simd, b: Self::Simd) -> Self::Simd {
+ Float32Type::mask_select(mask, a, b)
+ }
+
+ fn mask_any(mask: Self::SimdMask) -> bool {
+ Float32Type::mask_any(mask)
+ }
+
+ fn bin_op<F: Fn(Self::Simd, Self::Simd) -> Self::Simd>(
+ left: Self::Simd,
+ right: Self::Simd,
+ op: F,
+ ) -> Self::Simd {
+ op(left, right)
+ }
+
+ fn eq(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::eq(left, right)
+ }
+
+ fn ne(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::ne(left, right)
+ }
+
+ fn lt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::lt(left, right)
+ }
+
+ fn le(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::le(left, right)
+ }
+
+ fn gt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::gt(left, right)
+ }
+
+ fn ge(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+ Float32Type::ge(left, right)
+ }
+
+ fn write(simd_result: Self::Simd, slice: &mut [Self::Native]) {
+ let mut s = [0_f32; Self::Simd::lanes()];
+ Float32Type::write(simd_result, &mut s);
+ slice
+ .iter_mut()
+ .zip(s)
+ .for_each(|(o, i)| *o = half::f16::from_f32(i))
+ }
+
+ fn unary_op<F: Fn(Self::Simd) -> Self::Simd>(a: Self::Simd, op: F) -> Self::Simd {
+ Float32Type::unary_op(a, op)
+ }
+}
+
#[cfg(feature = "simd")]
pub trait ArrowFloatNumericType: ArrowNumericType {
fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd;