You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ne...@apache.org on 2020/11/07 11:25:00 UTC
[arrow] branch master updated: ARROW-10378: [Rust] Update take() kernel with support for LargeList.
This is an automated email from the ASF dual-hosted git repository.
nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e6366dc ARROW-10378: [Rust] Update take() kernel with support for LargeList.
e6366dc is described below
commit e6366dc753e2e121435f7936d059961702842445
Author: Daniel Russo <da...@gmail.com>
AuthorDate: Sat Nov 7 13:22:47 2020 +0200
ARROW-10378: [Rust] Update take() kernel with support for LargeList.
This change adds support for `LargeList` in `take()`.
There is an additional update to the underlying implementation of `take()` such that the indices may be any `PrimitiveArray` of `ArrowNumericType`, rather than only `UInt32Array`. This change is motivated by the recursive call to `take()` in `take_list()` ([here](https://github.com/apache/arrow/blob/b109195b77d85e513aab80650bd4b193e26a5471/rust/arrow/src/compute/kernels/take.rs#L324)), since in order to support `LargeListArray`, which uses `i64` offsets, the recursive call must support [...]
Closes #8556 from drusso/ARROW-10378
Authored-by: Daniel Russo <da...@gmail.com>
Signed-off-by: Neville Dipale <ne...@gmail.com>
---
rust/arrow/src/compute/kernels/take.rs | 711 ++++++++++++++++++++++-----------
rust/arrow/src/compute/util.rs | 122 ++++--
2 files changed, 567 insertions(+), 266 deletions(-)
diff --git a/rust/arrow/src/compute/kernels/take.rs b/rust/arrow/src/compute/kernels/take.rs
index b61e393..9cb7f36 100644
--- a/rust/arrow/src/compute/kernels/take.rs
+++ b/rust/arrow/src/compute/kernels/take.rs
@@ -26,7 +26,7 @@ use crate::error::{ArrowError, Result};
use crate::util::bit_util;
use crate::{array::*, buffer::buffer_bin_and};
-use num::Zero;
+use num::{ToPrimitive, Zero};
use TimeUnit::*;
/// Take elements from `ArrayRef` by copying the data from `values` at
@@ -57,12 +57,26 @@ pub fn take(
indices: &UInt32Array,
options: Option<TakeOptions>,
) -> Result<ArrayRef> {
+ take_impl::<UInt32Type>(values, indices, options)
+}
+
+fn take_impl<IndexType>(
+ values: &ArrayRef,
+ indices: &PrimitiveArray<IndexType>,
+ options: Option<TakeOptions>,
+) -> Result<ArrayRef>
+where
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
+{
let options = options.unwrap_or_default();
if options.check_bounds {
let len = values.len();
for i in 0..indices.len() {
if indices.is_valid(i) {
- let ix = indices.value(i) as usize;
+ let ix = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
if ix >= len {
return Err(ArrowError::ComputeError(
format!("Array index out of bounds, cannot get item at index {} from {} entries", ix, len))
@@ -73,68 +87,71 @@ pub fn take(
}
match values.data_type() {
DataType::Boolean => take_boolean(values, indices),
- DataType::Int8 => take_primitive::<Int8Type>(values, indices),
- DataType::Int16 => take_primitive::<Int16Type>(values, indices),
- DataType::Int32 => take_primitive::<Int32Type>(values, indices),
- DataType::Int64 => take_primitive::<Int64Type>(values, indices),
- DataType::UInt8 => take_primitive::<UInt8Type>(values, indices),
- DataType::UInt16 => take_primitive::<UInt16Type>(values, indices),
- DataType::UInt32 => take_primitive::<UInt32Type>(values, indices),
- DataType::UInt64 => take_primitive::<UInt64Type>(values, indices),
- DataType::Float32 => take_primitive::<Float32Type>(values, indices),
- DataType::Float64 => take_primitive::<Float64Type>(values, indices),
- DataType::Date32(_) => take_primitive::<Date32Type>(values, indices),
- DataType::Date64(_) => take_primitive::<Date64Type>(values, indices),
- DataType::Time32(Second) => take_primitive::<Time32SecondType>(values, indices),
+ DataType::Int8 => take_primitive::<Int8Type, _>(values, indices),
+ DataType::Int16 => take_primitive::<Int16Type, _>(values, indices),
+ DataType::Int32 => take_primitive::<Int32Type, _>(values, indices),
+ DataType::Int64 => take_primitive::<Int64Type, _>(values, indices),
+ DataType::UInt8 => take_primitive::<UInt8Type, _>(values, indices),
+ DataType::UInt16 => take_primitive::<UInt16Type, _>(values, indices),
+ DataType::UInt32 => take_primitive::<UInt32Type, _>(values, indices),
+ DataType::UInt64 => take_primitive::<UInt64Type, _>(values, indices),
+ DataType::Float32 => take_primitive::<Float32Type, _>(values, indices),
+ DataType::Float64 => take_primitive::<Float64Type, _>(values, indices),
+ DataType::Date32(_) => take_primitive::<Date32Type, _>(values, indices),
+ DataType::Date64(_) => take_primitive::<Date64Type, _>(values, indices),
+ DataType::Time32(Second) => {
+ take_primitive::<Time32SecondType, _>(values, indices)
+ }
DataType::Time32(Millisecond) => {
- take_primitive::<Time32MillisecondType>(values, indices)
+ take_primitive::<Time32MillisecondType, _>(values, indices)
}
DataType::Time64(Microsecond) => {
- take_primitive::<Time64MicrosecondType>(values, indices)
+ take_primitive::<Time64MicrosecondType, _>(values, indices)
}
DataType::Time64(Nanosecond) => {
- take_primitive::<Time64NanosecondType>(values, indices)
+ take_primitive::<Time64NanosecondType, _>(values, indices)
}
DataType::Timestamp(Second, _) => {
- take_primitive::<TimestampSecondType>(values, indices)
+ take_primitive::<TimestampSecondType, _>(values, indices)
}
DataType::Timestamp(Millisecond, _) => {
- take_primitive::<TimestampMillisecondType>(values, indices)
+ take_primitive::<TimestampMillisecondType, _>(values, indices)
}
DataType::Timestamp(Microsecond, _) => {
- take_primitive::<TimestampMicrosecondType>(values, indices)
+ take_primitive::<TimestampMicrosecondType, _>(values, indices)
}
DataType::Timestamp(Nanosecond, _) => {
- take_primitive::<TimestampNanosecondType>(values, indices)
+ take_primitive::<TimestampNanosecondType, _>(values, indices)
}
DataType::Interval(IntervalUnit::YearMonth) => {
- take_primitive::<IntervalYearMonthType>(values, indices)
+ take_primitive::<IntervalYearMonthType, _>(values, indices)
}
DataType::Interval(IntervalUnit::DayTime) => {
- take_primitive::<IntervalDayTimeType>(values, indices)
+ take_primitive::<IntervalDayTimeType, _>(values, indices)
}
DataType::Duration(TimeUnit::Second) => {
- take_primitive::<DurationSecondType>(values, indices)
+ take_primitive::<DurationSecondType, _>(values, indices)
}
DataType::Duration(TimeUnit::Millisecond) => {
- take_primitive::<DurationMillisecondType>(values, indices)
+ take_primitive::<DurationMillisecondType, _>(values, indices)
}
DataType::Duration(TimeUnit::Microsecond) => {
- take_primitive::<DurationMicrosecondType>(values, indices)
+ take_primitive::<DurationMicrosecondType, _>(values, indices)
}
DataType::Duration(TimeUnit::Nanosecond) => {
- take_primitive::<DurationNanosecondType>(values, indices)
+ take_primitive::<DurationNanosecondType, _>(values, indices)
}
- DataType::Utf8 => take_string::<i32>(values, indices),
- DataType::LargeUtf8 => take_string::<i64>(values, indices),
- DataType::List(_) => take_list(values, indices),
+ DataType::Utf8 => take_string::<i32, _>(values, indices),
+ DataType::LargeUtf8 => take_string::<i64, _>(values, indices),
+ DataType::List(_) => take_list::<_, Int32Type>(values, indices),
+ DataType::LargeList(_) => take_list::<_, Int64Type>(values, indices),
DataType::Struct(fields) => {
let struct_: &StructArray =
values.as_any().downcast_ref::<StructArray>().unwrap();
let arrays: Result<Vec<ArrayRef>> = struct_
.columns()
.iter()
- .map(|a| take(a, indices, Some(options.clone())))
+ .map(|a| take_impl(a, indices, Some(options.clone())))
.collect();
let arrays = arrays?;
let pairs: Vec<(Field, ArrayRef)> =
@@ -142,14 +159,14 @@ pub fn take(
Ok(Arc::new(StructArray::from(pairs)) as ArrayRef)
}
DataType::Dictionary(key_type, _) => match key_type.as_ref() {
- DataType::Int8 => take_dict::<Int8Type>(values, indices),
- DataType::Int16 => take_dict::<Int16Type>(values, indices),
- DataType::Int32 => take_dict::<Int32Type>(values, indices),
- DataType::Int64 => take_dict::<Int64Type>(values, indices),
- DataType::UInt8 => take_dict::<UInt8Type>(values, indices),
- DataType::UInt16 => take_dict::<UInt16Type>(values, indices),
- DataType::UInt32 => take_dict::<UInt32Type>(values, indices),
- DataType::UInt64 => take_dict::<UInt64Type>(values, indices),
+ DataType::Int8 => take_dict::<Int8Type, _>(values, indices),
+ DataType::Int16 => take_dict::<Int16Type, _>(values, indices),
+ DataType::Int32 => take_dict::<Int32Type, _>(values, indices),
+ DataType::Int64 => take_dict::<Int64Type, _>(values, indices),
+ DataType::UInt8 => take_dict::<UInt8Type, _>(values, indices),
+ DataType::UInt16 => take_dict::<UInt16Type, _>(values, indices),
+ DataType::UInt32 => take_dict::<UInt32Type, _>(values, indices),
+ DataType::UInt64 => take_dict::<UInt64Type, _>(values, indices),
t => unimplemented!("Take not supported for dictionary key type {:?}", t),
},
t => unimplemented!("Take not supported for data type {:?}", t),
@@ -182,9 +199,14 @@ impl Default for TakeOptions {
/// values: [1, 2, 3, null, 5]
/// indices: [0, null, 4, 3]
/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
-fn take_primitive<T>(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef>
+fn take_primitive<T, I>(
+ values: &ArrayRef,
+ indices: &PrimitiveArray<I>,
+) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
+ I: ArrowNumericType,
+ I::Native: ToPrimitive,
{
let data_len = indices.len();
@@ -195,15 +217,23 @@ where
let null_slice = null_buf.data_mut();
- let new_values: Vec<T::Native> = (0..data_len)
- .map(|i| {
- let index = indices.value(i) as usize;
- if array.is_null(index) {
- bit_util::unset_bit(null_slice, i);
- }
- array.value(index)
- })
- .collect();
+ // This iteration is implemented with a while loop, rather than a
+ // map()/collect(), since the while loop performs better in the benchmarks.
+ let mut new_values: Vec<T::Native> = Vec::with_capacity(data_len);
+ let mut i = 0;
+ while i < data_len {
+ let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+
+ if array.is_null(index) {
+ bit_util::unset_bit(null_slice, i);
+ }
+
+ new_values.push(array.value(index));
+
+ i += 1;
+ }
let nulls = match indices.data_ref().null_buffer() {
Some(buffer) => buffer_bin_and(buffer, 0, &null_buf.freeze(), 0, indices.len()),
@@ -223,7 +253,14 @@ where
}
/// `take` implementation for boolean arrays
-fn take_boolean(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
+fn take_boolean<IndexType>(
+ values: &ArrayRef,
+ indices: &PrimitiveArray<IndexType>,
+) -> Result<ArrayRef>
+where
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
+{
let data_len = indices.len();
let array = values.as_any().downcast_ref::<BooleanArray>().unwrap();
@@ -235,14 +272,19 @@ fn take_boolean(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
let null_slice = null_buf.data_mut();
let val_slice = val_buf.data_mut();
- (0..data_len).for_each(|i| {
- let index = indices.value(i) as usize;
+ (0..data_len).try_for_each::<_, Result<()>>(|i| {
+ let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
+
if array.is_null(index) {
bit_util::unset_bit(null_slice, i);
} else if array.value(index) {
bit_util::set_bit(val_slice, i);
}
- });
+
+ Ok(())
+ })?;
let nulls = match indices.data_ref().null_buffer() {
Some(buffer) => buffer_bin_and(buffer, 0, &null_buf.freeze(), 0, indices.len()),
@@ -262,9 +304,14 @@ fn take_boolean(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
}
/// `take` implementation for string arrays
-fn take_string<OffsetSize>(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef>
+fn take_string<OffsetSize, IndexType>(
+ values: &ArrayRef,
+ indices: &PrimitiveArray<IndexType>,
+) -> Result<ArrayRef>
where
OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
{
let data_len = indices.len();
@@ -283,7 +330,9 @@ where
offsets.push(length_so_far);
for i in 0..data_len {
- let index = indices.value(i) as usize;
+ let index = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
if array.is_valid(index) && indices.is_valid(i) {
let s = array.value(index);
@@ -316,28 +365,43 @@ where
/// Calculates the index and indexed offset for the inner array,
/// applying `take` on the inner array, then reconstructing a list array
/// with the indexed offsets
-fn take_list(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
+fn take_list<IndexType, OffsetType>(
+ values: &ArrayRef,
+ indices: &PrimitiveArray<IndexType>,
+) -> Result<ArrayRef>
+where
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
+ OffsetType: ArrowNumericType,
+ OffsetType::Native: ToPrimitive + OffsetSizeTrait,
+ PrimitiveArray<OffsetType>: From<Vec<Option<OffsetType::Native>>>,
+{
// TODO: Some optimizations can be done here such as if it is
// taking the whole list or a contiguous sublist
- let list: &ListArray = values.as_any().downcast_ref::<ListArray>().unwrap();
- let (list_indices, offsets) = take_value_indices_from_list(values, indices);
- let taken = take(&list.values(), &list_indices, None)?;
+ let list = values
+ .as_any()
+ .downcast_ref::<GenericListArray<OffsetType::Native>>()
+ .unwrap();
+
+ let (list_indices, offsets) =
+ take_value_indices_from_list::<IndexType, OffsetType>(values, indices)?;
+
+ let taken = take_impl::<OffsetType>(&list.values(), &list_indices, None)?;
// determine null count and null buffer, which are a function of `values` and `indices`
let mut null_count = 0;
let num_bytes = bit_util::ceil(indices.len(), 8);
let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
{
let null_slice = null_buf.data_mut();
- offsets[..]
- .windows(2)
- .enumerate()
- .for_each(|(i, window): (usize, &[i32])| {
+ offsets[..].windows(2).enumerate().for_each(
+ |(i, window): (usize, &[OffsetType::Native])| {
if window[0] == window[1] {
// offsets are equal, slot is null
bit_util::unset_bit(null_slice, i);
null_count += 1;
}
- });
+ },
+ );
}
let value_offsets = Buffer::from(offsets[..].to_byte_slice());
// create a new list with taken data and computed null information
@@ -349,7 +413,8 @@ fn take_list(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
.add_child_data(taken.data())
.add_buffer(value_offsets)
.build();
- let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
+ let list_array =
+ Arc::new(GenericListArray::<OffsetType::Native>::from(list_data)) as ArrayRef;
Ok(list_array)
}
@@ -357,16 +422,18 @@ fn take_list(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef> {
///
/// applies `take` to the keys of the dictionary array and returns a new dictionary array
/// with the same dictionary values and reordered keys
-fn take_dict<T>(values: &ArrayRef, indices: &UInt32Array) -> Result<ArrayRef>
+fn take_dict<T, I>(values: &ArrayRef, indices: &PrimitiveArray<I>) -> Result<ArrayRef>
where
T: ArrowPrimitiveType,
+ I: ArrowNumericType,
+ I::Native: ToPrimitive,
{
let dict = values
.as_any()
.downcast_ref::<DictionaryArray<T>>()
.unwrap();
let keys: ArrayRef = Arc::new(dict.keys_array());
- let new_keys = take_primitive::<T>(&keys, indices)?;
+ let new_keys = take_primitive::<T, I>(&keys, indices)?;
let new_keys_data = new_keys.data_ref();
let data = Arc::new(ArrayData::new(
@@ -405,6 +472,27 @@ mod tests {
)
}
+ fn test_take_impl_primitive_arrays<T, I>(
+ data: Vec<Option<T::Native>>,
+ index: &PrimitiveArray<I>,
+ options: Option<TakeOptions>,
+ expected_data: Vec<Option<T::Native>>,
+ ) where
+ T: ArrowPrimitiveType,
+ PrimitiveArray<T>: From<Vec<Option<T::Native>>> + ArrayEqual,
+ I: ArrowNumericType,
+ I::Native: ToPrimitive,
+ {
+ let output = PrimitiveArray::<T>::from(data);
+ let expected = PrimitiveArray::<T>::from(expected_data);
+ let output = take_impl(&(Arc::new(output) as ArrayRef), index, options).unwrap();
+ let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+ assert!(
+ output.equals(&expected),
+ format!("{:?} =! {:?}", output.data(), expected.data())
+ )
+ }
+
// create a simple struct for testing purposes
fn create_test_struct() -> ArrayRef {
let boolean_data = BooleanArray::from(vec![true, false, false, true]).data();
@@ -426,6 +514,38 @@ mod tests {
fn test_take_primitive() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+ // int8
+ test_take_primitive_arrays::<Int8Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // int16
+ test_take_primitive_arrays::<Int16Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // int32
+ test_take_primitive_arrays::<Int32Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // int64
+ test_take_primitive_arrays::<Int64Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
// uint8
test_take_primitive_arrays::<UInt8Type>(
vec![Some(0), None, Some(2), Some(3), None],
@@ -524,6 +644,80 @@ mod tests {
}
#[test]
+ fn test_take_impl_primitive_with_int64_indices() {
+ let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+
+ // int16
+ test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // int64
+ test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
+ vec![Some(0), None, Some(2), Some(-15), None],
+ &index,
+ None,
+ vec![Some(-15), None, None, Some(-15), Some(2)],
+ );
+
+ // uint64
+ test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // duration_millisecond
+ test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
+ vec![Some(0), None, Some(2), Some(-15), None],
+ &index,
+ None,
+ vec![Some(-15), None, None, Some(-15), Some(2)],
+ );
+
+ // float32
+ test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
+ vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
+ &index,
+ None,
+ vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
+ );
+ }
+
+ #[test]
+ fn test_take_impl_primitive_with_uint8_indices() {
+ let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
+
+ // int16
+ test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
+ vec![Some(0), None, Some(2), Some(3), None],
+ &index,
+ None,
+ vec![Some(3), None, None, Some(3), Some(2)],
+ );
+
+ // duration_millisecond
+ test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
+ vec![Some(0), None, Some(2), Some(-15), None],
+ &index,
+ None,
+ vec![Some(-15), None, None, Some(-15), Some(2)],
+ );
+
+ // float32
+ test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
+ vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
+ &index,
+ None,
+ vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
+ );
+ }
+
+ #[test]
fn test_take_primitive_bool() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
// boolean
@@ -576,186 +770,255 @@ mod tests {
_test_take_string::<LargeStringArray>()
}
- #[test]
- fn test_take_list() {
- // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
- let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).data();
- // Construct offsets
- let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice());
- // Construct a list array from the above two
- let list_data_type = DataType::List(Box::new(DataType::Int32));
- let list_data = ArrayData::builder(list_data_type.clone())
- .len(3)
- .add_buffer(value_offsets)
- .add_child_data(value_data)
- .build();
- let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
+ macro_rules! test_take_list {
+ ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
+ // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
+ let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).data();
+ // Construct offsets
+ let value_offsets: [$offset_type; 4] = [0, 3, 6, 8];
+ let value_offsets = Buffer::from(&value_offsets.to_byte_slice());
+ // Construct a list array from the above two
+ let list_data_type = DataType::$list_data_type(Box::new(DataType::Int32));
+ let list_data = ArrayData::builder(list_data_type.clone())
+ .len(3)
+ .add_buffer(value_offsets)
+ .add_child_data(value_data)
+ .build();
+ let list_array = Arc::new($list_array_type::from(list_data)) as ArrayRef;
+
+ // index returns: [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
+ let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(2), Some(0)]);
+
+ let a = take(&list_array, &index, None).unwrap();
+ let a: &$list_array_type =
+ a.as_any().downcast_ref::<$list_array_type>().unwrap();
+
+ // construct a value array with expected results:
+ // [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
+ let expected_data = Int32Array::from(vec![
+ Some(2),
+ Some(3),
+ Some(-1),
+ Some(-2),
+ Some(-1),
+ Some(2),
+ Some(3),
+ Some(0),
+ Some(0),
+ Some(0),
+ ])
+ .data();
+ // construct offsets
+ let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 7, 10];
+ let expected_offsets = Buffer::from(&expected_offsets.to_byte_slice());
+ // construct list array from the two
+ let expected_list_data = ArrayData::builder(list_data_type)
+ .len(5)
+ .null_count(1)
+ // null buffer remains the same as only the indices have nulls
+ .null_bit_buffer(
+ index.data().null_bitmap().as_ref().unwrap().bits.clone(),
+ )
+ .add_buffer(expected_offsets)
+ .add_child_data(expected_data)
+ .build();
+ let expected_list_array = $list_array_type::from(expected_list_data);
+
+ assert!(a.equals(&expected_list_array));
+ }};
+ }
- // index returns: [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
- let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(2), Some(0)]);
+ macro_rules! test_take_list_with_value_nulls {
+ ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
+ // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
+ let value_data = Int32Array::from(vec![
+ Some(0),
+ None,
+ Some(0),
+ Some(-1),
+ Some(-2),
+ Some(3),
+ None,
+ Some(5),
+ None,
+ ])
+ .data();
+ // Construct offsets
+ let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
+ let value_offsets = Buffer::from(&value_offsets.to_byte_slice());
+ // Construct a list array from the above two
+ let list_data_type = DataType::$list_data_type(Box::new(DataType::Int32));
+ let list_data = ArrayData::builder(list_data_type.clone())
+ .len(4)
+ .add_buffer(value_offsets)
+ .null_count(0)
+ .null_bit_buffer(Buffer::from([0b10111101, 0b00000000]))
+ .add_child_data(value_data)
+ .build();
+ let list_array = Arc::new($list_array_type::from(list_data)) as ArrayRef;
+
+ // index returns: [[null], null, [-1,-2,3], [5,null], [0,null,0]]
+ let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
+
+ let a = take(&list_array, &index, None).unwrap();
+ let a: &$list_array_type =
+ a.as_any().downcast_ref::<$list_array_type>().unwrap();
+
+ // construct a value array with expected results:
+ // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
+ let expected_data = Int32Array::from(vec![
+ None,
+ Some(-1),
+ Some(-2),
+ Some(3),
+ Some(5),
+ None,
+ Some(0),
+ None,
+ Some(0),
+ ])
+ .data();
+ // construct offsets
+ let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
+ let expected_offsets = Buffer::from(&expected_offsets.to_byte_slice());
+ // construct list array from the two
+ let expected_list_data = ArrayData::builder(list_data_type)
+ .len(5)
+ .null_count(1)
+ // null buffer remains the same as only the indices have nulls
+ .null_bit_buffer(
+ index.data().null_bitmap().as_ref().unwrap().bits.clone(),
+ )
+ .add_buffer(expected_offsets)
+ .add_child_data(expected_data)
+ .build();
+ let expected_list_array = $list_array_type::from(expected_list_data);
+
+ assert!(a.equals(&expected_list_array));
+ }};
+ }
- let a = take(&list_array, &index, None).unwrap();
- let a: &ListArray = a.as_any().downcast_ref::<ListArray>().unwrap();
+ macro_rules! test_take_list_with_nulls {
+ ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
+ // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
+ let value_data = Int32Array::from(vec![
+ Some(0),
+ None,
+ Some(0),
+ Some(-1),
+ Some(-2),
+ Some(3),
+ Some(5),
+ None,
+ ])
+ .data();
+ // Construct offsets
+ let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
+ let value_offsets = Buffer::from(&value_offsets.to_byte_slice());
+ // Construct a list array from the above two
+ let list_data_type = DataType::$list_data_type(Box::new(DataType::Int32));
+ let list_data = ArrayData::builder(list_data_type.clone())
+ .len(4)
+ .add_buffer(value_offsets)
+ .null_count(1)
+ .null_bit_buffer(Buffer::from([0b01111101]))
+ .add_child_data(value_data)
+ .build();
+ let list_array = Arc::new($list_array_type::from(list_data)) as ArrayRef;
+
+ // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
+ let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
+
+ let a = take(&list_array, &index, None).unwrap();
+ let a: &$list_array_type =
+ a.as_any().downcast_ref::<$list_array_type>().unwrap();
+
+ // construct a value array with expected results:
+ // [null, null, [-1,-2,3], [5,null], [0,null,0]]
+ let expected_data = Int32Array::from(vec![
+ Some(-1),
+ Some(-2),
+ Some(3),
+ Some(5),
+ None,
+ Some(0),
+ None,
+ Some(0),
+ ])
+ .data();
+ // construct offsets
+ let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
+ let expected_offsets = Buffer::from(&expected_offsets.to_byte_slice());
+ // construct list array from the two
+ let mut null_bits: [u8; 1] = [0; 1];
+ bit_util::set_bit(&mut null_bits, 2);
+ bit_util::set_bit(&mut null_bits, 3);
+ bit_util::set_bit(&mut null_bits, 4);
+ let expected_list_data = ArrayData::builder(list_data_type)
+ .len(5)
+ .null_count(2)
+ // null buffer must be recalculated as both values and indices have nulls
+ .null_bit_buffer(Buffer::from(null_bits))
+ .add_buffer(expected_offsets)
+ .add_child_data(expected_data)
+ .build();
+ let expected_list_array = $list_array_type::from(expected_list_data);
+
+ assert!(a.equals(&expected_list_array));
+ }};
+ }
- // construct a value array with expected results:
- // [[2,3], null, [-1,-2,-1], [2,3], [0,0,0]]
- let expected_data = Int32Array::from(vec![
- Some(2),
- Some(3),
- Some(-1),
- Some(-2),
- Some(-1),
- Some(2),
- Some(3),
- Some(0),
- Some(0),
- Some(0),
- ])
- .data();
- // construct offsets
- let expected_offsets = Buffer::from(&[0, 2, 2, 5, 7, 10].to_byte_slice());
- // construct list array from the two
- let expected_list_data = ArrayData::builder(list_data_type)
- .len(5)
- .null_count(1)
- // null buffer remains the same as only the indices have nulls
- .null_bit_buffer(index.data().null_bitmap().as_ref().unwrap().bits.clone())
- .add_buffer(expected_offsets)
- .add_child_data(expected_data)
- .build();
- let expected_list_array = ListArray::from(expected_list_data);
+ #[test]
+ fn test_take_list() {
+ test_take_list!(i32, List, ListArray);
+ }
- assert!(a.equals(&expected_list_array));
+ #[test]
+ fn test_take_large_list() {
+ test_take_list!(i64, LargeList, LargeListArray);
}
#[test]
fn test_take_list_with_value_nulls() {
- // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
- let value_data = Int32Array::from(vec![
- Some(0),
- None,
- Some(0),
- Some(-1),
- Some(-2),
- Some(3),
- None,
- Some(5),
- None,
- ])
- .data();
- // Construct offsets
- let value_offsets = Buffer::from(&[0, 3, 6, 7, 9].to_byte_slice());
- // Construct a list array from the above two
- let list_data_type = DataType::List(Box::new(DataType::Int32));
- let list_data = ArrayData::builder(list_data_type.clone())
- .len(4)
- .add_buffer(value_offsets)
- .null_count(0)
- .null_bit_buffer(Buffer::from([0b10111101, 0b00000000]))
- .add_child_data(value_data)
- .build();
- let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
-
- // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
- let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
+ test_take_list_with_value_nulls!(i32, List, ListArray);
+ }
- let a = take(&list_array, &index, None).unwrap();
- let a: &ListArray = a.as_any().downcast_ref::<ListArray>().unwrap();
+ #[test]
+ fn test_take_large_list_with_value_nulls() {
+ test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
+ }
- // construct a value array with expected results:
- // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
- let expected_data = Int32Array::from(vec![
- None,
- Some(-1),
- Some(-2),
- Some(3),
- Some(5),
- None,
- Some(0),
- None,
- Some(0),
- ])
- .data();
- // construct offsets
- let expected_offsets = Buffer::from(&[0, 1, 1, 4, 6, 9].to_byte_slice());
- // construct list array from the two
- let expected_list_data = ArrayData::builder(list_data_type)
- .len(5)
- .null_count(1)
- // null buffer remains the same as only the indices have nulls
- .null_bit_buffer(index.data().null_bitmap().as_ref().unwrap().bits.clone())
- .add_buffer(expected_offsets)
- .add_child_data(expected_data)
- .build();
- let expected_list_array = ListArray::from(expected_list_data);
+ #[test]
+ fn test_test_take_list_with_nulls() {
+ test_take_list_with_nulls!(i32, List, ListArray);
+ }
- assert!(a.equals(&expected_list_array));
+ #[test]
+ fn test_test_take_large_list_with_nulls() {
+ test_take_list_with_nulls!(i64, LargeList, LargeListArray);
}
#[test]
- fn test_take_list_with_list_nulls() {
- // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
- let value_data = Int32Array::from(vec![
- Some(0),
- None,
- Some(0),
- Some(-1),
- Some(-2),
- Some(3),
- Some(5),
- None,
- ])
- .data();
+ #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
+ fn test_take_list_out_of_bounds() {
+ // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
+ let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).data();
// Construct offsets
- let value_offsets = Buffer::from(&[0, 3, 6, 6, 8].to_byte_slice());
+ let value_offsets = Buffer::from(&[0, 3, 6, 8].to_byte_slice());
// Construct a list array from the above two
let list_data_type = DataType::List(Box::new(DataType::Int32));
let list_data = ArrayData::builder(list_data_type.clone())
- .len(4)
+ .len(3)
.add_buffer(value_offsets)
- .null_count(1)
- .null_bit_buffer(Buffer::from([0b01111101]))
.add_child_data(value_data)
.build();
let list_array = Arc::new(ListArray::from(list_data)) as ArrayRef;
- // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
- let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
-
- let a = take(&list_array, &index, None).unwrap();
- let a: &ListArray = a.as_any().downcast_ref::<ListArray>().unwrap();
-
- // construct a value array with expected results:
- // [null, null, [-1,-2,3], [5,null], [0,null,0]]
- let expected_data = Int32Array::from(vec![
- Some(-1),
- Some(-2),
- Some(3),
- Some(5),
- None,
- Some(0),
- None,
- Some(0),
- ])
- .data();
- // construct offsets
- let expected_offsets = Buffer::from(&[0, 0, 0, 3, 5, 8].to_byte_slice());
- // construct list array from the two
- let mut null_bits: [u8; 1] = [0; 1];
- bit_util::set_bit(&mut null_bits, 2);
- bit_util::set_bit(&mut null_bits, 3);
- bit_util::set_bit(&mut null_bits, 4);
- let expected_list_data = ArrayData::builder(list_data_type)
- .len(5)
- .null_count(2)
- // null buffer must be recalculated as both values and indices have nulls
- .null_bit_buffer(Buffer::from(null_bits))
- .add_buffer(expected_offsets)
- .add_child_data(expected_data)
- .build();
- let expected_list_array = ListArray::from(expected_list_data);
+ let index = UInt32Array::from(vec![1000]);
- assert!(a.equals(&expected_list_array));
+ // A panic is expected here since we have not supplied the check_bounds
+ // option.
+ take(&list_array, &index, None).unwrap();
}
#[test]
diff --git a/rust/arrow/src/compute/util.rs b/rust/arrow/src/compute/util.rs
index e499dc3..686cc8b 100644
--- a/rust/arrow/src/compute/util.rs
+++ b/rust/arrow/src/compute/util.rs
@@ -21,13 +21,12 @@ use crate::array::*;
#[cfg(feature = "simd")]
use crate::bitmap::Bitmap;
use crate::buffer::{buffer_bin_and, buffer_bin_or, Buffer};
-#[cfg(feature = "simd")]
use crate::datatypes::*;
-use crate::error::Result;
-#[cfg(feature = "simd")]
-use num::One;
+use crate::error::{ArrowError, Result};
+use num::{One, ToPrimitive, Zero};
#[cfg(feature = "simd")]
use std::cmp::min;
+use std::ops::Add;
/// Combines the null bitmaps of two arrays using a bitwise `and` operation.
///
@@ -100,41 +99,55 @@ pub(super) fn compare_option_bitmap(
/// Where a list array has indices `[0,2,5,10]`, taking indices of `[2,0]` returns
/// an array of the indices `[5..10, 0..2]` and offsets `[0,5,7]` (5 elements and 2
/// elements)
-pub(super) fn take_value_indices_from_list(
+pub(super) fn take_value_indices_from_list<IndexType, OffsetType>(
values: &ArrayRef,
- indices: &UInt32Array,
-) -> (UInt32Array, Vec<i32>) {
+ indices: &PrimitiveArray<IndexType>,
+) -> Result<(PrimitiveArray<OffsetType>, Vec<OffsetType::Native>)>
+where
+ IndexType: ArrowNumericType,
+ IndexType::Native: ToPrimitive,
+ OffsetType: ArrowNumericType,
+ OffsetType::Native: OffsetSizeTrait + Add + Zero + One,
+ PrimitiveArray<OffsetType>: From<Vec<Option<OffsetType::Native>>>,
+{
// TODO: benchmark this function, there might be a faster unsafe alternative
// get list array's offsets
- let list: &ListArray = values.as_any().downcast_ref::<ListArray>().unwrap();
- let offsets: Vec<u32> = (0..=list.len())
- .map(|i| list.value_offset(i) as u32)
- .collect();
+ let list = values
+ .as_any()
+ .downcast_ref::<GenericListArray<OffsetType::Native>>()
+ .unwrap();
+ let offsets: Vec<OffsetType::Native> =
+ (0..=list.len()).map(|i| list.value_offset(i)).collect();
+
let mut new_offsets = Vec::with_capacity(indices.len());
let mut values = Vec::new();
- let mut current_offset = 0;
+ let mut current_offset = OffsetType::Native::zero();
// add first offset
- new_offsets.push(0);
+ new_offsets.push(OffsetType::Native::zero());
// compute the value indices, and set offsets accordingly
for i in 0..indices.len() {
if indices.is_valid(i) {
- let ix = indices.value(i) as usize;
+ let ix = ToPrimitive::to_usize(&indices.value(i)).ok_or_else(|| {
+ ArrowError::ComputeError("Cast to usize failed".to_string())
+ })?;
let start = offsets[ix];
let end = offsets[ix + 1];
- current_offset += (end - start) as i32;
+ current_offset = current_offset + (end - start);
new_offsets.push(current_offset);
+
+ let mut curr = start;
+
// if start == end, this slot is empty
- if start != end {
- // type annotation needed to guide compiler a bit
- let mut offsets: Vec<Option<u32>> =
- (start..end).map(Some).collect::<Vec<Option<u32>>>();
- values.append(&mut offsets);
+ while curr < end {
+ values.push(Some(curr));
+ curr = curr + OffsetType::Native::one();
}
} else {
new_offsets.push(current_offset);
}
}
- (UInt32Array::from(values), new_offsets)
+
+ Ok((PrimitiveArray::<OffsetType>::from(values), new_offsets))
}
/// Creates a new SIMD mask, i.e. `packed_simd::m32x16` or similar. that indicates if the
@@ -285,31 +298,56 @@ mod tests {
);
}
- #[test]
- fn test_take_value_index_from_list() {
- let value_data = Int32Array::from((0..10).collect::<Vec<i32>>()).data();
- let value_offsets = Buffer::from(&[0, 2, 5, 10].to_byte_slice());
- let list_data_type = DataType::List(Box::new(DataType::Int32));
+ fn build_list<P, S>(
+ list_data_type: DataType,
+ values: PrimitiveArray<P>,
+ offsets: Vec<S>,
+ ) -> ArrayRef
+ where
+ P: ArrowPrimitiveType,
+ S: OffsetSizeTrait,
+ {
+ let value_data = values.data();
+ let value_offsets = Buffer::from(&offsets[..].to_byte_slice());
let list_data = ArrayData::builder(list_data_type)
- .len(3)
+ .len(offsets.len() - 1)
.add_buffer(value_offsets)
.add_child_data(value_data)
.build();
- let array = Arc::new(ListArray::from(list_data)) as ArrayRef;
- let index = UInt32Array::from(vec![2, 0]);
- let (indexed, offsets) = take_value_indices_from_list(&array, &index);
- assert_eq!(vec![0, 5, 7], offsets);
- let data = UInt32Array::from(vec![
- Some(5),
- Some(6),
- Some(7),
- Some(8),
- Some(9),
- Some(0),
- Some(1),
- ])
- .data();
- assert_eq!(data, indexed.data());
+ let array = Arc::new(GenericListArray::<S>::from(list_data)) as ArrayRef;
+ array
+ }
+
+ #[test]
+ fn test_take_value_index_from_list() {
+ let list = build_list(
+ DataType::List(Box::new(DataType::Int32)),
+ Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+ vec![0i32, 2i32, 5i32, 10i32],
+ );
+ let indices = UInt32Array::from(vec![2, 0]);
+
+ let (indexed, offsets) =
+ take_value_indices_from_list::<_, Int32Type>(&list, &indices).unwrap();
+
+ assert_eq!(indexed, Int32Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
+ assert_eq!(offsets, vec![0, 5, 7]);
+ }
+
+ #[test]
+ fn test_take_value_index_from_large_list() {
+ let list = build_list(
+ DataType::LargeList(Box::new(DataType::Int32)),
+ Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+ vec![0i64, 2i64, 5i64, 10i64],
+ );
+ let indices = UInt32Array::from(vec![2, 0]);
+
+ let (indexed, offsets) =
+ take_value_indices_from_list::<_, Int64Type>(&list, &indices).unwrap();
+
+ assert_eq!(indexed, Int64Array::from(vec![5, 6, 7, 8, 9, 0, 1]));
+ assert_eq!(offsets, vec![0, 5, 7]);
}
#[test]