You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/07/06 20:28:14 UTC

[arrow-rs] branch master updated: Support DictionaryArray in unary kernel (#1990)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 62053a801 Support DictionaryArray in unary kernel (#1990)
62053a801 is described below

commit 62053a801e0a8e6b22778314c19a37929e96b76a
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Jul 6 13:28:10 2022 -0700

    Support DictionaryArray in unary kernel (#1990)
    
    * Init
    
    * More
    
    * Fix clippy
    
    * Apply on dictionary values directly in unary_dict.
    
    * Fix clippy
    
    * Avoid validate when constructing new dictionary array
---
 arrow/src/compute/kernels/arity.rs | 181 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 177 insertions(+), 4 deletions(-)

diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 60a0cb77f..513521816 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -17,9 +17,14 @@
 
 //! Defines kernels suitable to perform operations to primitive arrays.
 
-use crate::array::{Array, ArrayData, PrimitiveArray};
+use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
 use crate::buffer::Buffer;
-use crate::datatypes::ArrowPrimitiveType;
+use crate::datatypes::{
+    ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
+    Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use crate::error::{ArrowError, Result};
+use std::sync::Arc;
 
 #[inline]
 fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
@@ -78,10 +83,163 @@ where
     PrimitiveArray::<O>::from(data)
 }

+/// Applies a unary function to the values of a dictionary array with primitive values.
+///
+/// `op` is applied to each non-null dictionary value (not once per key), while the keys
+/// are reused as they are, including their offset and validity bitmap. The result is a
+/// dictionary array with the same key type and the same logical key layout as `array`.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::InvalidArgumentError`] if the dictionary values are not a
+/// `PrimitiveArray<T>`.
+fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
+where
+    K: ArrowNumericType,
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    let dict_values = array
+        .values()
+        .as_any()
+        .downcast_ref::<PrimitiveArray<T>>()
+        .ok_or_else(|| {
+            ArrowError::InvalidArgumentError(format!(
+                "Cannot perform unary operation on dictionary array with value type {}.",
+                array.values().data_type()
+            ))
+        })?;
+
+    // Null dictionary values stay null; `op` only sees non-null values.
+    let values = dict_values
+        .iter()
+        .map(|v| v.map(&op))
+        .collect::<PrimitiveArray<T>>();
+
+    let keys = array.keys();
+    let null_count = keys.null_count();
+    // Only attach a validity bitmap when at least one key is null.
+    let null_buffer = if null_count > 0 {
+        keys.data().null_buffer().cloned()
+    } else {
+        None
+    };
+
+    // `keys` may be a slice (e.g. via `Array::slice`); its offset must be carried over,
+    // otherwise the key buffer and the validity bitmap would be read from position 0
+    // instead of from the start of the slice.
+    let data = ArrayData::builder(array.data_type().clone())
+        .len(keys.len())
+        .offset(keys.offset())
+        .add_buffer(keys.data().buffers()[0].clone())
+        .null_bit_buffer(null_buffer)
+        .null_count(null_count)
+        .add_child_data(values.data().clone());
+
+    // SAFETY: the key buffer, offset, length and validity bitmap are taken unchanged from
+    // the existing key array, and `values` has exactly as many slots as the original
+    // dictionary values, so every valid key still indexes an existing value slot.
+    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
+    Ok(Arc::new(new_dict))
+}
+
+/// Applies a unary function to an array with primitive values.
+///
+/// `array` may be a `PrimitiveArray<T>` or a [`DictionaryArray`] with an integer key type
+/// and `PrimitiveArray<T>` values. For a dictionary, `op` is applied to the dictionary
+/// values only and the keys are kept, so the result is again a dictionary array with the
+/// same key type.
+///
+/// # Errors
+///
+/// Returns [`ArrowError::NotYetImplemented`] for a dictionary whose key type is not an
+/// integer type, and [`ArrowError::InvalidArgumentError`] if the (dictionary) values are
+/// not a `PrimitiveArray<T>`.
+pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    // Downcasts `array` to a dictionary with key type `$key` and applies `op` to it. The
+    // downcast cannot fail: the arm is chosen from the array's own dictionary key type.
+    macro_rules! dict_op {
+        ($key:ty) => {
+            unary_dict::<$key, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<$key>>()
+                    .expect("dictionary key type matches the array's data type"),
+                op,
+            )
+        };
+    }
+
+    match array.data_type() {
+        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+            DataType::Int8 => dict_op!(Int8Type),
+            DataType::Int16 => dict_op!(Int16Type),
+            DataType::Int32 => dict_op!(Int32Type),
+            DataType::Int64 => dict_op!(Int64Type),
+            DataType::UInt8 => dict_op!(UInt8Type),
+            DataType::UInt16 => dict_op!(UInt16Type),
+            DataType::UInt32 => dict_op!(UInt32Type),
+            DataType::UInt64 => dict_op!(UInt64Type),
+            t => Err(ArrowError::NotYetImplemented(format!(
+                "Cannot perform unary operation on dictionary array of key type {}.",
+                t
+            ))),
+        },
+        _ => {
+            let primitive = array
+                .as_any()
+                .downcast_ref::<PrimitiveArray<T>>()
+                .ok_or_else(|| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "Cannot perform unary operation on array of type {}.",
+                        array.data_type()
+                    ))
+                })?;
+            Ok(Arc::new(unary::<T, F, T>(primitive, op)))
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::array::{as_primitive_array, Float64Array};
+    use crate::array::{
+        as_primitive_array, Float64Array, PrimitiveBuilder, PrimitiveDictionaryBuilder,
+    };
+    use crate::datatypes::{Float64Type, Int32Type, Int8Type};
+
+    // Regression test: the keys' offset of a sliced dictionary array must be preserved.
+    #[test]
+    fn test_unary_dyn_sliced_dict() {
+        let key_builder = PrimitiveBuilder::<Int8Type>::new(3);
+        let value_builder = PrimitiveBuilder::<Int32Type>::new(2);
+        let mut builder = PrimitiveDictionaryBuilder::new(key_builder, value_builder);
+        builder.append(5).unwrap();
+        builder.append(6).unwrap();
+        builder.append(7).unwrap();
+        builder.append_null().unwrap();
+        builder.append(8).unwrap();
+        let dictionary_array = builder.finish();
+
+        // Logical values of the slice are [6, 7, null].
+        let sliced = dictionary_array.slice(1, 3);
+        let result = unary_dyn::<_, Int32Type>(sliced.as_ref(), |n| n + 1).unwrap();
+        let dict = result
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int8Type>>()
+            .unwrap();
+        let values = as_primitive_array::<Int32Type>(dict.values());
+        let actual: Vec<Option<i32>> = dict
+            .keys()
+            .iter()
+            .map(|k| k.map(|k| values.value(k as usize)))
+            .collect();
+        assert_eq!(actual, vec![Some(7), Some(8), None]);
+    }
 
     #[test]
     fn test_unary_f64_slice() {
@@ -93,6 +216,53 @@ mod tests {
         assert_eq!(
             result,
             Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
-        )
+        );
+
+        let result = unary_dyn::<_, Float64Type>(input_slice, |n| n + 1.0).unwrap();
+
+        assert_eq!(
+            result.as_any().downcast_ref::<Float64Array>().unwrap(),
+            &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
+        );
+    }
+
+    #[test]
+    fn test_unary_dict_and_unary_dyn() {
+        // Builds an Int8-keyed dictionary of Int32 values; `None` appends a null key.
+        fn build_dict(items: &[Option<i32>]) -> DictionaryArray<Int8Type> {
+            let mut builder = PrimitiveDictionaryBuilder::new(
+                PrimitiveBuilder::<Int8Type>::new(3),
+                PrimitiveBuilder::<Int32Type>::new(2),
+            );
+            for item in items {
+                match item {
+                    Some(value) => {
+                        builder.append(*value).unwrap();
+                    }
+                    None => builder.append_null().unwrap(),
+                }
+            }
+            builder.finish()
+        }
+
+        let dictionary_array =
+            build_dict(&[Some(5), Some(6), Some(7), Some(8), None, Some(9)]);
+        // Every value is incremented by one; the null key stays null.
+        let expected = build_dict(&[Some(6), Some(7), Some(8), Some(9), None, Some(10)]);
+
+        // The typed helper and the dynamic entry point must produce the same array.
+        let assert_expected = |result: ArrayRef| {
+            let actual = result
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int8Type>>()
+                .unwrap();
+            assert_eq!(actual, &expected);
+        };
+
+        let result = unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        assert_expected(result);
+
+        let result = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        assert_expected(result);
     }
 }