You are viewing a plain text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/10/03 14:59:48 UTC

[arrow-rs] branch master updated: Implement ArrowNumericType for Float16Type (#2810)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c1748f9c Implement ArrowNumericType for Float16Type (#2810)
9c1748f9c is described below

commit 9c1748f9cb6a125e18e64bd5da17cc1782a4b2a5
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Oct 3 15:59:42 2022 +0100

    Implement ArrowNumericType for Float16Type (#2810)
    
    * Implement ArrowNumericType for Float16Type
    
    * Remove unnecessary safety comments
---
 arrow/src/compute/kernels/arithmetic.rs | 31 ++++++++---
 arrow/src/compute/kernels/comparison.rs |  7 +--
 arrow/src/datatypes/numeric.rs          | 96 +++++++++++++++++++++++++++++++++
 3 files changed, 123 insertions(+), 11 deletions(-)

diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs
index 1c28c9895..b2e95ad5e 100644
--- a/arrow/src/compute/kernels/arithmetic.rs
+++ b/arrow/src/compute/kernels/arithmetic.rs
@@ -332,9 +332,7 @@ where
 
             // process data in chunks of 64 elements since we also get 64 bits of validity information at a time
 
-            // safety: result is newly created above, always written as a T below
-            let mut result_chunks =
-                unsafe { result.typed_data_mut().chunks_exact_mut(64) };
+            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(64);
             let mut left_chunks = left.values().chunks_exact(64);
             let mut right_chunks = right.values().chunks_exact(64);
 
@@ -380,9 +378,7 @@ where
             )?;
         }
         None => {
-            // safety: result is newly created above, always written as a T below
-            let mut result_chunks =
-                unsafe { result.typed_data_mut().chunks_exact_mut(lanes) };
+            let mut result_chunks = result.typed_data_mut().chunks_exact_mut(lanes);
             let mut left_chunks = left.values().chunks_exact(lanes);
             let mut right_chunks = right.values().chunks_exact(lanes);
 
@@ -1611,6 +1607,7 @@ mod tests {
     use crate::array::Int32Array;
     use crate::datatypes::{Date64Type, Int32Type, Int8Type};
     use chrono::NaiveDate;
+    use half::f16;
 
     #[test]
     fn test_primitive_array_add() {
@@ -2898,4 +2895,26 @@ mod tests {
         let division_by_zero = divide_scalar_opt_dyn::<Int32Type>(&a, 0);
         assert_eq!(&expected, &division_by_zero.unwrap());
     }
+
+    #[test]
+    fn test_sum_f16() {
+        let a = Float16Array::from_iter_values([
+            f16::from_f32(0.1),
+            f16::from_f32(0.2),
+            f16::from_f32(1.5),
+            f16::from_f32(-0.1),
+        ]);
+        let b = Float16Array::from_iter_values([
+            f16::from_f32(5.1),
+            f16::from_f32(6.2),
+            f16::from_f32(-1.),
+            f16::from_f32(-2.1),
+        ]);
+        let expected = Float16Array::from_iter_values(
+            a.values().iter().zip(b.values()).map(|(a, b)| a + b),
+        );
+
+        let c = add(&a, &b).unwrap();
+        assert_eq!(c, expected);
+    }
 }
diff --git a/arrow/src/compute/kernels/comparison.rs b/arrow/src/compute/kernels/comparison.rs
index 49aecfb67..1ea433150 100644
--- a/arrow/src/compute/kernels/comparison.rs
+++ b/arrow/src/compute/kernels/comparison.rs
@@ -1792,7 +1792,6 @@ where
         .iter()
         .map(|key| {
             key.map(|key| unsafe {
-                // safety lengths were verified above
                 let key = key.as_usize();
                 dict_comparison.value_unchecked(key)
             })
@@ -1845,8 +1844,7 @@ where
     let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
     let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE);
 
-    // safety: result is newly created above, always written as a T below
-    let result_chunks = unsafe { result.typed_data_mut() };
+    let result_chunks = result.typed_data_mut();
     let result_remainder = left_chunks
         .borrow_mut()
         .zip(right_chunks.borrow_mut())
@@ -1937,8 +1935,7 @@ where
     let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE);
     let simd_right = T::init(right);
 
-    // safety: result is newly created above, always written as a T below
-    let result_chunks = unsafe { result.typed_data_mut() };
+    let result_chunks = result.typed_data_mut();
     let result_remainder =
         left_chunks
             .borrow_mut()
diff --git a/arrow/src/datatypes/numeric.rs b/arrow/src/datatypes/numeric.rs
index b8fa87197..e74764d4c 100644
--- a/arrow/src/datatypes/numeric.rs
+++ b/arrow/src/datatypes/numeric.rs
@@ -366,6 +366,102 @@ make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8);
 make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8);
 make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8);
 
+#[cfg(not(feature = "simd"))]
+impl ArrowNumericType for Float16Type {}
+
+#[cfg(feature = "simd")]
+impl ArrowNumericType for Float16Type {
+    type Simd = <Float32Type as ArrowNumericType>::Simd;
+    type SimdMask = <Float32Type as ArrowNumericType>::SimdMask;
+
+    fn lanes() -> usize {
+        Float32Type::lanes()
+    }
+
+    fn init(value: Self::Native) -> Self::Simd {
+        Float32Type::init(value.to_f32())
+    }
+
+    fn load(slice: &[Self::Native]) -> Self::Simd {
+        let mut s = [0_f32; Self::Simd::lanes()];
+        s.iter_mut().zip(slice).for_each(|(o, a)| *o = a.to_f32());
+        Float32Type::load(&s)
+    }
+
+    fn mask_init(value: bool) -> Self::SimdMask {
+        Float32Type::mask_init(value)
+    }
+
+    fn mask_from_u64(mask: u64) -> Self::SimdMask {
+        Float32Type::mask_from_u64(mask)
+    }
+
+    fn mask_to_u64(mask: &Self::SimdMask) -> u64 {
+        Float32Type::mask_to_u64(mask)
+    }
+
+    fn mask_get(mask: &Self::SimdMask, idx: usize) -> bool {
+        Float32Type::mask_get(mask, idx)
+    }
+
+    fn mask_set(mask: Self::SimdMask, idx: usize, value: bool) -> Self::SimdMask {
+        Float32Type::mask_set(mask, idx, value)
+    }
+
+    fn mask_select(mask: Self::SimdMask, a: Self::Simd, b: Self::Simd) -> Self::Simd {
+        Float32Type::mask_select(mask, a, b)
+    }
+
+    fn mask_any(mask: Self::SimdMask) -> bool {
+        Float32Type::mask_any(mask)
+    }
+
+    fn bin_op<F: Fn(Self::Simd, Self::Simd) -> Self::Simd>(
+        left: Self::Simd,
+        right: Self::Simd,
+        op: F,
+    ) -> Self::Simd {
+        op(left, right)
+    }
+
+    fn eq(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::eq(left, right)
+    }
+
+    fn ne(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::ne(left, right)
+    }
+
+    fn lt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::lt(left, right)
+    }
+
+    fn le(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::le(left, right)
+    }
+
+    fn gt(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::gt(left, right)
+    }
+
+    fn ge(left: Self::Simd, right: Self::Simd) -> Self::SimdMask {
+        Float32Type::ge(left, right)
+    }
+
+    fn write(simd_result: Self::Simd, slice: &mut [Self::Native]) {
+        let mut s = [0_f32; Self::Simd::lanes()];
+        Float32Type::write(simd_result, &mut s);
+        slice
+            .iter_mut()
+            .zip(s)
+            .for_each(|(o, i)| *o = half::f16::from_f32(i))
+    }
+
+    fn unary_op<F: Fn(Self::Simd) -> Self::Simd>(a: Self::Simd, op: F) -> Self::Simd {
+        Float32Type::unary_op(a, op)
+    }
+}
+
 #[cfg(feature = "simd")]
 pub trait ArrowFloatNumericType: ArrowNumericType {
     fn pow(base: Self::Simd, raise: Self::Simd) -> Self::Simd;