You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/04/22 12:44:41 UTC

[arrow-rs] branch master updated: ARROW-12426: [Rust] Fix concatenation of arrow dictionaries (#15)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a5732e8  ARROW-12426: [Rust] Fix concatentation of arrow dictionaries (#15)
a5732e8 is described below

commit a5732e80508c60cb65cc049c57f8bca5d38fc1c4
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Apr 22 13:44:34 2021 +0100

    ARROW-12426: [Rust] Fix concatentation of arrow dictionaries (#15)
---
 arrow/src/array/transform/mod.rs       | 120 +++++++++++++++++++++++++++++----
 arrow/src/array/transform/primitive.rs |  15 +++++
 arrow/src/compute/kernels/concat.rs    |  68 +++++++++++++++++++
 3 files changed, 189 insertions(+), 14 deletions(-)

diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs
index 4dc7b56..e7ec41e 100644
--- a/arrow/src/array/transform/mod.rs
+++ b/arrow/src/array/transform/mod.rs
@@ -15,7 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util};
+use crate::{
+    buffer::MutableBuffer,
+    datatypes::DataType,
+    error::{ArrowError, Result},
+    util::bit_util,
+};
 
 use super::{
     data::{into_buffers, new_buffers},
@@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> {
     }
 }
 
+/// Builds an extend that adds `offset` to the source primitive
+/// Additionally validates that `max` fits into the
+/// the underlying primitive returning None if not
+fn build_extend_dictionary(
+    array: &ArrayData,
+    offset: usize,
+    max: usize,
+) -> Option<Extend> {
+    use crate::datatypes::*;
+    use std::convert::TryInto;
+
+    match array.data_type() {
+        DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
+            DataType::UInt8 => {
+                let _: u8 = max.try_into().ok()?;
+                let offset: u8 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::UInt16 => {
+                let _: u16 = max.try_into().ok()?;
+                let offset: u16 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::UInt32 => {
+                let _: u32 = max.try_into().ok()?;
+                let offset: u32 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::UInt64 => {
+                let _: u64 = max.try_into().ok()?;
+                let offset: u64 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::Int8 => {
+                let _: i8 = max.try_into().ok()?;
+                let offset: i8 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::Int16 => {
+                let _: i16 = max.try_into().ok()?;
+                let offset: i16 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::Int32 => {
+                let _: i32 = max.try_into().ok()?;
+                let offset: i32 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            DataType::Int64 => {
+                let _: i64 = max.try_into().ok()?;
+                let offset: i64 = offset.try_into().ok()?;
+                Some(primitive::build_extend_with_offset(array, offset))
+            }
+            _ => unreachable!(),
+        },
+        _ => None,
+    }
+}
+
 fn build_extend(array: &ArrayData) -> Extend {
     use crate::datatypes::*;
     match array.data_type() {
@@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend {
         }
         DataType::List(_) => list::build_extend::<i32>(array),
         DataType::LargeList(_) => list::build_extend::<i64>(array),
-        DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
-            DataType::UInt8 => primitive::build_extend::<u8>(array),
-            DataType::UInt16 => primitive::build_extend::<u16>(array),
-            DataType::UInt32 => primitive::build_extend::<u32>(array),
-            DataType::UInt64 => primitive::build_extend::<u64>(array),
-            DataType::Int8 => primitive::build_extend::<i8>(array),
-            DataType::Int16 => primitive::build_extend::<i16>(array),
-            DataType::Int32 => primitive::build_extend::<i32>(array),
-            DataType::Int64 => primitive::build_extend::<i64>(array),
-            _ => unreachable!(),
-        },
+        DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
         DataType::Struct(_) => structure::build_extend(array),
         DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
         DataType::Float16 => unreachable!(),
@@ -339,7 +393,29 @@ impl<'a> MutableArrayData<'a> {
         };
 
         let dictionary = match &data_type {
-            DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()),
+            DataType::Dictionary(_, _) => match arrays.len() {
+                0 => unreachable!(),
+                1 => Some(arrays[0].child_data()[0].clone()),
+                _ => {
+                    // Concat dictionaries together
+                    let dictionaries: Vec<_> =
+                        arrays.iter().map(|array| &array.child_data()[0]).collect();
+                    let lengths: Vec<_> = dictionaries
+                        .iter()
+                        .map(|dictionary| dictionary.len())
+                        .collect();
+                    let capacity = lengths.iter().sum();
+
+                    let mut mutable =
+                        MutableArrayData::new(dictionaries, false, capacity);
+
+                    for (i, len) in lengths.iter().enumerate() {
+                        mutable.extend(i, 0, *len)
+                    }
+
+                    Some(mutable.freeze())
+                }
+            },
             _ => None,
         };
 
@@ -353,7 +429,23 @@ impl<'a> MutableArrayData<'a> {
         let null_bytes = bit_util::ceil(capacity, 8);
         let null_buffer = MutableBuffer::from_len_zeroed(null_bytes);
 
-        let extend_values = arrays.iter().map(|array| build_extend(array)).collect();
+        let extend_values = match &data_type {
+            DataType::Dictionary(_, _) => {
+                let mut next_offset = 0;
+                let extend_values: Result<Vec<_>> = arrays
+                    .iter()
+                    .map(|array| {
+                        let offset = next_offset;
+                        next_offset += array.child_data()[0].len();
+                        build_extend_dictionary(array, offset, next_offset)
+                            .ok_or(ArrowError::DictionaryKeyOverflowError)
+                    })
+                    .collect();
+
+                extend_values.expect("MutableArrayData::new is infallible")
+            }
+            _ => arrays.iter().map(|array| build_extend(array)).collect(),
+        };
 
         let data = _MutableArrayData {
             data_type: data_type.clone(),
diff --git a/arrow/src/array/transform/primitive.rs b/arrow/src/array/transform/primitive.rs
index 032bb4a..4c765c0 100644
--- a/arrow/src/array/transform/primitive.rs
+++ b/arrow/src/array/transform/primitive.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use std::mem::size_of;
+use std::ops::Add;
 
 use crate::{array::ArrayData, datatypes::ArrowNativeType};
 
@@ -32,6 +33,20 @@ pub(super) fn build_extend<T: ArrowNativeType>(array: &ArrayData) -> Extend {
     )
 }
 
+pub(super) fn build_extend_with_offset<T>(array: &ArrayData, offset: T) -> Extend
+where
+    T: ArrowNativeType + Add<Output = T>,
+{
+    let values = array.buffer::<T>(0);
+    Box::new(
+        move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
+            mutable
+                .buffer1
+                .extend(values[start..start + len].iter().map(|x| *x + offset));
+        },
+    )
+}
+
 pub(super) fn extend_nulls<T: ArrowNativeType>(
     mutable: &mut _MutableArrayData,
     len: usize,
diff --git a/arrow/src/compute/kernels/concat.rs b/arrow/src/compute/kernels/concat.rs
index 3288028..35ff183 100644
--- a/arrow/src/compute/kernels/concat.rs
+++ b/arrow/src/compute/kernels/concat.rs
@@ -384,4 +384,72 @@ mod tests {
 
         Ok(())
     }
+
+    fn collect_string_dictionary(
+        dictionary: &DictionaryArray<Int32Type>,
+    ) -> Vec<Option<String>> {
+        let values = dictionary.values();
+        let values = values.as_any().downcast_ref::<StringArray>().unwrap();
+
+        dictionary
+            .keys()
+            .iter()
+            .map(|key| key.map(|key| values.value(key as _).to_string()))
+            .collect()
+    }
+
+    fn concat_dictionary(
+        input_1: DictionaryArray<Int32Type>,
+        input_2: DictionaryArray<Int32Type>,
+    ) -> Vec<Option<String>> {
+        let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
+        let concat = concat
+            .as_any()
+            .downcast_ref::<DictionaryArray<Int32Type>>()
+            .unwrap();
+
+        collect_string_dictionary(concat)
+    }
+
+    #[test]
+    fn test_string_dictionary_array() {
+        let input_1: DictionaryArray<Int32Type> =
+            vec!["hello", "A", "B", "hello", "hello", "C"]
+                .into_iter()
+                .collect();
+        let input_2: DictionaryArray<Int32Type> =
+            vec!["hello", "E", "E", "hello", "F", "E"]
+                .into_iter()
+                .collect();
+
+        let expected: Vec<_> = vec![
+            "hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F",
+            "E",
+        ]
+        .into_iter()
+        .map(|x| Some(x.to_string()))
+        .collect();
+
+        let concat = concat_dictionary(input_1, input_2);
+        assert_eq!(concat, expected);
+    }
+
+    #[test]
+    fn test_string_dictionary_array_nulls() {
+        let input_1: DictionaryArray<Int32Type> =
+            vec![Some("foo"), Some("bar"), None, Some("fiz")]
+                .into_iter()
+                .collect();
+        let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
+        let expected = vec![
+            Some("foo".to_string()),
+            Some("bar".to_string()),
+            None,
+            Some("fiz".to_string()),
+            None,
+        ];
+
+        let concat = concat_dictionary(input_1, input_2);
+        assert_eq!(concat, expected);
+    }
 }