You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/07/06 20:28:14 UTC
[arrow-rs] branch master updated: Support DictionaryArray in unary kernel (#1990)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 62053a801 Support DictionaryArray in unary kernel (#1990)
62053a801 is described below
commit 62053a801e0a8e6b22778314c19a37929e96b76a
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Jul 6 13:28:10 2022 -0700
Support DictionaryArray in unary kernel (#1990)
* Init
* More
* Fix clippy
* Apply on dictionary values directly in unary_dict.
* Fix clippy
* Avoid validation when constructing new dictionary array
---
arrow/src/compute/kernels/arity.rs | 181 ++++++++++++++++++++++++++++++++++++-
1 file changed, 177 insertions(+), 4 deletions(-)
diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs
index 60a0cb77f..513521816 100644
--- a/arrow/src/compute/kernels/arity.rs
+++ b/arrow/src/compute/kernels/arity.rs
@@ -17,9 +17,14 @@
//! Defines kernels suitable to perform operations to primitive arrays.
-use crate::array::{Array, ArrayData, PrimitiveArray};
+use crate::array::{Array, ArrayData, ArrayRef, DictionaryArray, PrimitiveArray};
use crate::buffer::Buffer;
-use crate::datatypes::ArrowPrimitiveType;
+use crate::datatypes::{
+ ArrowNumericType, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type,
+ Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use crate::error::{ArrowError, Result};
+use std::sync::Arc;
#[inline]
fn into_primitive_array_data<I: ArrowPrimitiveType, O: ArrowPrimitiveType>(
@@ -78,10 +83,128 @@ where
PrimitiveArray::<O>::from(data)
}
+/// A helper function that applies a unary function to a dictionary array with primitive value type.
+///
+/// `op` is evaluated once per entry of the dictionary's values array (null entries stay
+/// null), not once per key. The key buffer and validity bitmap of `array` are reused unchanged.
+///
+/// # Errors
+///
+/// Returns `ArrowError::InvalidArgumentError` if the dictionary values are not a
+/// `PrimitiveArray<T>`.
+fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
+where
+    K: ArrowNumericType,
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    let dict_values = array
+        .values()
+        .as_any()
+        .downcast_ref::<PrimitiveArray<T>>()
+        .ok_or_else(|| {
+            ArrowError::InvalidArgumentError(format!(
+                "Cannot perform unary operation of type {} on dictionary values of type {}.",
+                T::DATA_TYPE,
+                array.values().data_type()
+            ))
+        })?;
+
+    // Transform the (usually small) set of distinct values rather than every row.
+    let values = dict_values
+        .iter()
+        .map(|v| v.map(&op))
+        .collect::<PrimitiveArray<T>>();
+
+    let keys = array.keys();
+    let key_data = keys.data();
+
+    let data = ArrayData::builder(array.data_type().clone())
+        .len(keys.len())
+        // Carry over the logical offset: the key and validity buffers below are the
+        // unsliced ones, so a sliced dictionary would otherwise read the wrong keys.
+        .offset(key_data.offset())
+        .add_buffer(key_data.buffers()[0].clone())
+        .add_child_data(values.data().clone());
+
+    let data = match key_data.null_buffer() {
+        Some(buffer) if key_data.null_count() > 0 => data
+            .null_bit_buffer(Some(buffer.clone()))
+            .null_count(key_data.null_count()),
+        _ => data.null_count(0),
+    };
+
+    // SAFETY: keys, validity bitmap, offset and length are taken unchanged from `array`,
+    // and `values` has exactly as many entries as the original dictionary values, so the
+    // keys index into it exactly as they indexed into the original values.
+    let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
+    Ok(Arc::new(new_dict))
+}
+
+/// Applies a unary function to an array with primitive values.
+///
+/// A dictionary array keyed by any signed or unsigned 8/16/32/64-bit integer type is
+/// forwarded to `unary_dict`; every other array is treated as a `PrimitiveArray<T>`.
+///
+/// # Errors
+///
+/// Returns `ArrowError::NotYetImplemented` for a dictionary with any other key type, and
+/// `ArrowError::InvalidArgumentError` if a non-dictionary `array` is not a
+/// `PrimitiveArray<T>`.
+pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
+where
+    T: ArrowPrimitiveType,
+    F: Fn(T::Native) -> T::Native,
+{
+    // Downcasts `array` to a dictionary with the given key type and applies `op` to it.
+    macro_rules! dict_arm {
+        ($key_type:ty) => {
+            unary_dict::<_, F, T>(
+                array
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<$key_type>>()
+                    .expect("dictionary key type was checked by the enclosing match"),
+                op,
+            )
+        };
+    }
+
+    match array.data_type() {
+        DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+            DataType::Int8 => dict_arm!(Int8Type),
+            DataType::Int16 => dict_arm!(Int16Type),
+            DataType::Int32 => dict_arm!(Int32Type),
+            DataType::Int64 => dict_arm!(Int64Type),
+            DataType::UInt8 => dict_arm!(UInt8Type),
+            DataType::UInt16 => dict_arm!(UInt16Type),
+            DataType::UInt32 => dict_arm!(UInt32Type),
+            DataType::UInt64 => dict_arm!(UInt64Type),
+            t => Err(ArrowError::NotYetImplemented(format!(
+                "Cannot perform unary operation on dictionary array of key type {}.",
+                t
+            ))),
+        },
+        t => {
+            // Report a mismatching array type as an error instead of panicking.
+            let primitive = array
+                .as_any()
+                .downcast_ref::<PrimitiveArray<T>>()
+                .ok_or_else(|| {
+                    ArrowError::InvalidArgumentError(format!(
+                        "Cannot perform unary operation of type {} on array of type {}.",
+                        T::DATA_TYPE,
+                        t
+                    ))
+                })?;
+            Ok(Arc::new(unary::<T, F, T>(primitive, op)))
+        }
+    }
+}
+
#[cfg(test)]
mod tests {
use super::*;
- use crate::array::{as_primitive_array, Float64Array};
+ use crate::array::{
+ as_primitive_array, Float64Array, PrimitiveBuilder, PrimitiveDictionaryBuilder,
+ };
+ use crate::datatypes::{Float64Type, Int32Type, Int8Type};
#[test]
fn test_unary_f64_slice() {
@@ -93,6 +216,56 @@ mod tests {
assert_eq!(
result,
Float64Array::from(vec![None, Some(7.0), None, Some(7.0)])
- )
+ );
+
+ let result = unary_dyn::<_, Float64Type>(input_slice, |n| n + 1.0).unwrap();
+
+ assert_eq!(
+ result.as_any().downcast_ref::<Float64Array>().unwrap(),
+ &Float64Array::from(vec![None, Some(7.8), None, Some(8.2)])
+ );
+ }
+
+    /// Checks that `unary_dict` and `unary_dyn` both add one to every non-null entry of an
+    /// `Int8`-keyed, `Int32`-valued dictionary array and leave the null entry null.
+    #[test]
+    fn test_unary_dict_and_unary_dyn() {
+        // Builds the dictionary [start, start + 1, start + 2, start + 3, null, start + 4].
+        let build_dict = |start: i32| {
+            let mut dict_builder = PrimitiveDictionaryBuilder::new(
+                PrimitiveBuilder::<Int8Type>::new(3),
+                PrimitiveBuilder::<Int32Type>::new(2),
+            );
+            for value in start..start + 4 {
+                dict_builder.append(value).unwrap();
+            }
+            dict_builder.append_null().unwrap();
+            dict_builder.append(start + 4).unwrap();
+            dict_builder.finish()
+        };
+
+        let dictionary_array = build_dict(5);
+        let expected = build_dict(6);
+
+        let from_dict =
+            unary_dict::<_, _, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        let from_dict = from_dict
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int8Type>>()
+            .unwrap();
+        assert_eq!(from_dict, &expected);
+
+        let from_dyn = unary_dyn::<_, Int32Type>(&dictionary_array, |n| n + 1).unwrap();
+        let from_dyn = from_dyn
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int8Type>>()
+            .unwrap();
+        assert_eq!(from_dyn, &expected);
     }
}