You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/04/22 12:44:41 UTC
[arrow-rs] branch master updated: ARROW-12426: [Rust] Fix
concatenation of arrow dictionaries (#15)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new a5732e8 ARROW-12426: [Rust] Fix concatenation of arrow dictionaries (#15)
a5732e8 is described below
commit a5732e80508c60cb65cc049c57f8bca5d38fc1c4
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Apr 22 13:44:34 2021 +0100
ARROW-12426: [Rust] Fix concatenation of arrow dictionaries (#15)
---
arrow/src/array/transform/mod.rs | 120 +++++++++++++++++++++++++++++----
arrow/src/array/transform/primitive.rs | 15 +++++
arrow/src/compute/kernels/concat.rs | 68 +++++++++++++++++++
3 files changed, 189 insertions(+), 14 deletions(-)
diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs
index 4dc7b56..e7ec41e 100644
--- a/arrow/src/array/transform/mod.rs
+++ b/arrow/src/array/transform/mod.rs
@@ -15,7 +15,12 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util};
+use crate::{
+ buffer::MutableBuffer,
+ datatypes::DataType,
+ error::{ArrowError, Result},
+ util::bit_util,
+};
use super::{
data::{into_buffers, new_buffers},
@@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> {
}
}
+/// Builds an extend that adds `offset` to the source primitive values.
+/// Additionally validates that `max` fits into the underlying
+/// primitive type, returning `None` if it does not.
+fn build_extend_dictionary(
+ array: &ArrayData,
+ offset: usize,
+ max: usize,
+) -> Option<Extend> {
+ use crate::datatypes::*;
+ use std::convert::TryInto;
+
+ match array.data_type() {
+ DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
+ DataType::UInt8 => {
+ let _: u8 = max.try_into().ok()?;
+ let offset: u8 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::UInt16 => {
+ let _: u16 = max.try_into().ok()?;
+ let offset: u16 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::UInt32 => {
+ let _: u32 = max.try_into().ok()?;
+ let offset: u32 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::UInt64 => {
+ let _: u64 = max.try_into().ok()?;
+ let offset: u64 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::Int8 => {
+ let _: i8 = max.try_into().ok()?;
+ let offset: i8 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::Int16 => {
+ let _: i16 = max.try_into().ok()?;
+ let offset: i16 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::Int32 => {
+ let _: i32 = max.try_into().ok()?;
+ let offset: i32 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ DataType::Int64 => {
+ let _: i64 = max.try_into().ok()?;
+ let offset: i64 = offset.try_into().ok()?;
+ Some(primitive::build_extend_with_offset(array, offset))
+ }
+ _ => unreachable!(),
+ },
+ _ => None,
+ }
+}
+
fn build_extend(array: &ArrayData) -> Extend {
use crate::datatypes::*;
match array.data_type() {
@@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend {
}
DataType::List(_) => list::build_extend::<i32>(array),
DataType::LargeList(_) => list::build_extend::<i64>(array),
- DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
- DataType::UInt8 => primitive::build_extend::<u8>(array),
- DataType::UInt16 => primitive::build_extend::<u16>(array),
- DataType::UInt32 => primitive::build_extend::<u32>(array),
- DataType::UInt64 => primitive::build_extend::<u64>(array),
- DataType::Int8 => primitive::build_extend::<i8>(array),
- DataType::Int16 => primitive::build_extend::<i16>(array),
- DataType::Int32 => primitive::build_extend::<i32>(array),
- DataType::Int64 => primitive::build_extend::<i64>(array),
- _ => unreachable!(),
- },
+ DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => unreachable!(),
@@ -339,7 +393,29 @@ impl<'a> MutableArrayData<'a> {
};
let dictionary = match &data_type {
- DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()),
+ DataType::Dictionary(_, _) => match arrays.len() {
+ 0 => unreachable!(),
+ 1 => Some(arrays[0].child_data()[0].clone()),
+ _ => {
+ // Concat dictionaries together
+ let dictionaries: Vec<_> =
+ arrays.iter().map(|array| &array.child_data()[0]).collect();
+ let lengths: Vec<_> = dictionaries
+ .iter()
+ .map(|dictionary| dictionary.len())
+ .collect();
+ let capacity = lengths.iter().sum();
+
+ let mut mutable =
+ MutableArrayData::new(dictionaries, false, capacity);
+
+ for (i, len) in lengths.iter().enumerate() {
+ mutable.extend(i, 0, *len)
+ }
+
+ Some(mutable.freeze())
+ }
+ },
_ => None,
};
@@ -353,7 +429,23 @@ impl<'a> MutableArrayData<'a> {
let null_bytes = bit_util::ceil(capacity, 8);
let null_buffer = MutableBuffer::from_len_zeroed(null_bytes);
- let extend_values = arrays.iter().map(|array| build_extend(array)).collect();
+ let extend_values = match &data_type {
+ DataType::Dictionary(_, _) => {
+ let mut next_offset = 0;
+ let extend_values: Result<Vec<_>> = arrays
+ .iter()
+ .map(|array| {
+ let offset = next_offset;
+ next_offset += array.child_data()[0].len();
+ build_extend_dictionary(array, offset, next_offset)
+ .ok_or(ArrowError::DictionaryKeyOverflowError)
+ })
+ .collect();
+
+ extend_values.expect("MutableArrayData::new is infallible")
+ }
+ _ => arrays.iter().map(|array| build_extend(array)).collect(),
+ };
let data = _MutableArrayData {
data_type: data_type.clone(),
diff --git a/arrow/src/array/transform/primitive.rs b/arrow/src/array/transform/primitive.rs
index 032bb4a..4c765c0 100644
--- a/arrow/src/array/transform/primitive.rs
+++ b/arrow/src/array/transform/primitive.rs
@@ -16,6 +16,7 @@
// under the License.
use std::mem::size_of;
+use std::ops::Add;
use crate::{array::ArrayData, datatypes::ArrowNativeType};
@@ -32,6 +33,20 @@ pub(super) fn build_extend<T: ArrowNativeType>(array: &ArrayData) -> Extend {
)
}
+pub(super) fn build_extend_with_offset<T>(array: &ArrayData, offset: T) -> Extend
+where
+ T: ArrowNativeType + Add<Output = T>,
+{
+ let values = array.buffer::<T>(0);
+ Box::new(
+ move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
+ mutable
+ .buffer1
+ .extend(values[start..start + len].iter().map(|x| *x + offset));
+ },
+ )
+}
+
pub(super) fn extend_nulls<T: ArrowNativeType>(
mutable: &mut _MutableArrayData,
len: usize,
diff --git a/arrow/src/compute/kernels/concat.rs b/arrow/src/compute/kernels/concat.rs
index 3288028..35ff183 100644
--- a/arrow/src/compute/kernels/concat.rs
+++ b/arrow/src/compute/kernels/concat.rs
@@ -384,4 +384,72 @@ mod tests {
Ok(())
}
+
+ fn collect_string_dictionary(
+ dictionary: &DictionaryArray<Int32Type>,
+ ) -> Vec<Option<String>> {
+ let values = dictionary.values();
+ let values = values.as_any().downcast_ref::<StringArray>().unwrap();
+
+ dictionary
+ .keys()
+ .iter()
+ .map(|key| key.map(|key| values.value(key as _).to_string()))
+ .collect()
+ }
+
+ fn concat_dictionary(
+ input_1: DictionaryArray<Int32Type>,
+ input_2: DictionaryArray<Int32Type>,
+ ) -> Vec<Option<String>> {
+ let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
+ let concat = concat
+ .as_any()
+ .downcast_ref::<DictionaryArray<Int32Type>>()
+ .unwrap();
+
+ collect_string_dictionary(concat)
+ }
+
+ #[test]
+ fn test_string_dictionary_array() {
+ let input_1: DictionaryArray<Int32Type> =
+ vec!["hello", "A", "B", "hello", "hello", "C"]
+ .into_iter()
+ .collect();
+ let input_2: DictionaryArray<Int32Type> =
+ vec!["hello", "E", "E", "hello", "F", "E"]
+ .into_iter()
+ .collect();
+
+ let expected: Vec<_> = vec![
+ "hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F",
+ "E",
+ ]
+ .into_iter()
+ .map(|x| Some(x.to_string()))
+ .collect();
+
+ let concat = concat_dictionary(input_1, input_2);
+ assert_eq!(concat, expected);
+ }
+
+ #[test]
+ fn test_string_dictionary_array_nulls() {
+ let input_1: DictionaryArray<Int32Type> =
+ vec![Some("foo"), Some("bar"), None, Some("fiz")]
+ .into_iter()
+ .collect();
+ let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
+ let expected = vec![
+ Some("foo".to_string()),
+ Some("bar".to_string()),
+ None,
+ Some("fiz".to_string()),
+ None,
+ ];
+
+ let concat = concat_dictionary(input_1, input_2);
+ assert_eq!(concat, expected);
+ }
}