You are viewing a plain-text version of this content; the canonical link to the original (the "here" hyperlink) was not preserved in this copy.
Posted to commits@arrow.apache.org by vi...@apache.org on 2022/05/09 17:57:26 UTC

[arrow-rs] branch master updated: Add dictionary array support for substring function (#1665)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new daed6ab58 Add dictionary array support for substring function (#1665)
daed6ab58 is described below

commit daed6ab58a58c84794de53400e3494aa5585e2e5
Author: Chao Sun <su...@apache.org>
AuthorDate: Mon May 9 10:57:21 2022 -0700

    Add dictionary array support for substring function (#1665)
    
    * initial commit
    
    * add test
    
    * comments
    
    * more comments
---
 arrow/src/compute/kernels/substring.rs | 260 ++++++++++++++++++++++-----------
 1 file changed, 175 insertions(+), 85 deletions(-)

diff --git a/arrow/src/compute/kernels/substring.rs b/arrow/src/compute/kernels/substring.rs
index a4d24435f..0ef488d97 100644
--- a/arrow/src/compute/kernels/substring.rs
+++ b/arrow/src/compute/kernels/substring.rs
@@ -18,13 +18,137 @@
 //! Defines kernel to extract a substring of an Array
 //! Supported array types: \[Large\]StringArray, \[Large\]BinaryArray
 
+use crate::array::DictionaryArray;
 use crate::buffer::MutableBuffer;
+use crate::datatypes::*;
 use crate::{array::*, buffer::Buffer};
 use crate::{
     datatypes::DataType,
     error::{ArrowError, Result},
 };
 use std::cmp::Ordering;
+use std::sync::Arc;
+
+/// Returns an [`ArrayRef`] with substrings of all the elements in `array`.
+///
+/// # Arguments
+///
+/// * `start` - The start index of all substrings.
+/// If `start >= 0`, then count from the start of the string,
+/// otherwise count from the end of the string.
+///
+/// * `length`(option) - The length of all substrings.
+/// If `length` is `None`, then the substring is from `start` to the end of the string.
+///
+/// Attention: Both `start` and `length` are counted by byte, not by char.
+///
+/// # Basic usage
+/// ```
+/// # use arrow::array::StringArray;
+/// # use arrow::compute::kernels::substring::substring;
+/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
+/// let result = substring(&array, 1, Some(4)).unwrap();
+/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
+/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
+/// ```
+///
+/// # Errors
+/// - The function errors when the passed array is not a \[Large\]String, \[Large\]Binary or
+///   FixedSizeBinary array, or a DictionaryArray with an integer key and one of those as values.
+/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array).
+///
+/// ## Example of trying to get an invalid utf-8 format substring
+/// ```
+/// # use arrow::array::StringArray;
+/// # use arrow::compute::kernels::substring::substring;
+/// let array = StringArray::from(vec![Some("E=mc²")]);
+/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
+/// assert!(error.contains("invalid utf-8 boundary"));
+/// ```
+pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
+    macro_rules! substring_dict {
+        ($kt: ident, $($t: ident: $gt: ident), *) => {
+            match $kt.as_ref() {
+                $(
+                    &DataType::$t => {
+                        let dict = array
+                            .as_any()
+                            .downcast_ref::<DictionaryArray<$gt>>()
+                            .unwrap_or_else(|| {
+                                panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}",
+                                       stringify!($gt), array.data_type())
+                            });
+                        let values = substring(dict.values(), start, length)?; // only the distinct values are sliced
+                        let result = DictionaryArray::try_new(dict.keys(), &values)?; // keys are reused unchanged
+                        Ok(Arc::new(result))
+                    },
+                )*
+                    t => Err(ArrowError::ComputeError(format!("Unsupported dictionary key type: {}", t))),
+            }
+        }
+    }
+
+    match array.data_type() {
+        DataType::Dictionary(kt, _) => {
+            substring_dict!(
+                kt,
+                Int8: Int8Type,
+                Int16: Int16Type,
+                Int32: Int32Type,
+                Int64: Int64Type,
+                UInt8: UInt8Type,
+                UInt16: UInt16Type,
+                UInt32: UInt32Type,
+                UInt64: UInt64Type
+            )
+        }
+        DataType::LargeBinary => binary_substring(
+            array
+                .as_any()
+                .downcast_ref::<LargeBinaryArray>()
+                .expect("A large binary is expected"),
+            start,
+            length.map(|e| e as i64),
+        ),
+        DataType::Binary => binary_substring(
+            array
+                .as_any()
+                .downcast_ref::<BinaryArray>()
+                .expect("A binary is expected"),
+            start as i32,
+            length.map(|e| e as i32),
+        ),
+        DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
+            array
+                .as_any()
+                .downcast_ref::<FixedSizeBinaryArray>()
+                .expect("a fixed size binary is expected"),
+            *old_len,
+            start as i32,
+            length.map(|e| e as i32),
+        ),
+        DataType::LargeUtf8 => utf8_substring(
+            array
+                .as_any()
+                .downcast_ref::<LargeStringArray>()
+                .expect("A large string is expected"),
+            start,
+            length.map(|e| e as i64),
+        ),
+        DataType::Utf8 => utf8_substring(
+            array
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .expect("A string is expected"),
+            start as i32,
+            length.map(|e| e as i32),
+        ),
+        _ => Err(ArrowError::ComputeError(format!(
+            "substring does not support type {:?}",
+            array.data_type()
+        ))),
+    }
+}
 
 fn binary_substring<OffsetSize: OffsetSizeTrait>(
     array: &GenericBinaryArray<OffsetSize>,
@@ -215,94 +339,10 @@ fn utf8_substring<OffsetSize: OffsetSizeTrait>(
     Ok(make_array(data))
 }
 
-/// Returns an ArrayRef with substrings of all the elements in `array`.
-///
-/// # Arguments
-///
-/// * `start` - The start index of all substrings.
-/// If `start >= 0`, then count from the start of the string,
-/// otherwise count from the end of the string.
-///
-/// * `length`(option) - The length of all substrings.
-/// If `length` is `None`, then the substring is from `start` to the end of the string.
-///
-/// Attention: Both `start` and `length` are counted by byte, not by char.
-///
-/// # Basic usage
-/// ```
-/// # use arrow::array::StringArray;
-/// # use arrow::compute::kernels::substring::substring;
-/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
-/// let result = substring(&array, 1, Some(4)).unwrap();
-/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
-/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
-/// ```
-///
-/// # Error
-/// - The function errors when the passed array is not a \[Large\]String array or \[Large\]Binary array.
-/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array).
-///
-/// ## Example of trying to get an invalid utf-8 format substring
-/// ```
-/// # use arrow::array::StringArray;
-/// # use arrow::compute::kernels::substring::substring;
-/// let array = StringArray::from(vec![Some("E=mc²")]);
-/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
-/// assert!(error.contains("invalid utf-8 boundary"));
-/// ```
-pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
-    match array.data_type() {
-        DataType::LargeBinary => binary_substring(
-            array
-                .as_any()
-                .downcast_ref::<LargeBinaryArray>()
-                .expect("A large binary is expected"),
-            start,
-            length.map(|e| e as i64),
-        ),
-        DataType::Binary => binary_substring(
-            array
-                .as_any()
-                .downcast_ref::<BinaryArray>()
-                .expect("A binary is expected"),
-            start as i32,
-            length.map(|e| e as i32),
-        ),
-        DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
-            array
-                .as_any()
-                .downcast_ref::<FixedSizeBinaryArray>()
-                .expect("a fixed size binary is expected"),
-            *old_len,
-            start as i32,
-            length.map(|e| e as i32),
-        ),
-        DataType::LargeUtf8 => utf8_substring(
-            array
-                .as_any()
-                .downcast_ref::<LargeStringArray>()
-                .expect("A large string is expected"),
-            start,
-            length.map(|e| e as i64),
-        ),
-        DataType::Utf8 => utf8_substring(
-            array
-                .as_any()
-                .downcast_ref::<StringArray>()
-                .expect("A string is expected"),
-            start as i32,
-            length.map(|e| e as i32),
-        ),
-        _ => Err(ArrowError::ComputeError(format!(
-            "substring does not support type {:?}",
-            array.data_type()
-        ))),
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::datatypes::*;
 
     #[allow(clippy::type_complexity)]
     fn with_nulls_generic_binary<O: OffsetSizeTrait>() -> Result<()> {
@@ -954,6 +994,56 @@ mod tests {
         without_nulls_generic_string::<i64>()
     }
 
+    #[test]
+    fn dictionary() -> Result<()> {
+        _dictionary::<Int8Type>()?;
+        _dictionary::<Int16Type>()?;
+        _dictionary::<Int32Type>()?;
+        _dictionary::<Int64Type>()?;
+        _dictionary::<UInt8Type>()?;
+        _dictionary::<UInt16Type>()?;
+        _dictionary::<UInt32Type>()?;
+        _dictionary::<UInt64Type>()?;
+        Ok(())
+    }
+
+    fn _dictionary<K: ArrowDictionaryKeyType>() -> Result<()> {
+        const TOTAL: i32 = 100;
+
+        let v = ["aaa", "bbb", "ccc", "ddd", "eee"];
+        let data: Vec<Option<&str>> = (0..TOTAL)
+            .map(|n| {
+                let i = n % 5;
+                if i == 3 { // every fifth row (index 3 mod 5) is null
+                    None
+                } else {
+                    Some(v[i as usize])
+                }
+            })
+            .collect();
+
+        let dict_array: DictionaryArray<K> = data.iter().copied().collect();
+
+        let expected: Vec<Option<&str>> =
+            data.iter().map(|opt| opt.map(|s| &s[1..3])).collect();
+
+        let res = substring(&dict_array, 1, Some(2))?;
+        let actual = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
+        let actual: Vec<Option<&str>> = actual
+            .values()
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .unwrap()
+            .take_iter(actual.keys_iter())
+            .collect();
+
+        // Compare whole vectors so a length mismatch is reported, not ignored.
+        assert_eq!(actual.len(), TOTAL as usize);
+        assert_eq!(expected, actual);
+
+        Ok(())
+    }
+
     #[test]
     fn check_invalid_array_type() {
         let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);