You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/10/21 19:27:00 UTC
[arrow-rs] branch master updated: Add specialized interleave implementation for primitives (#2898)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 9e5e47716 Add specialized interleave implementation for primitives (#2898)
9e5e47716 is described below
commit 9e5e47716898ade5e6ffcff1f77551f82d55a1b8
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Sat Oct 22 08:26:56 2022 +1300
Add specialized interleave implementation for primitives (#2898)
---
arrow-select/src/interleave.rs | 63 +++++++++++++++++++++++--
arrow/Cargo.toml | 5 ++
arrow/benches/interleave_kernels.rs | 91 +++++++++++++++++++++++++++++++++++++
3 files changed, 155 insertions(+), 4 deletions(-)
diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs
index 537075f1f..29f75894d 100644
--- a/arrow-select/src/interleave.rs
+++ b/arrow-select/src/interleave.rs
@@ -15,9 +15,22 @@
// specific language governing permissions and limitations
// under the License.
-use arrow_array::{make_array, new_empty_array, Array, ArrayRef};
+use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder};
+use arrow_array::cast::as_primitive_array;
+use arrow_array::{
+ downcast_primitive, make_array, new_empty_array, Array, ArrayRef, ArrowPrimitiveType,
+ PrimitiveArray,
+};
use arrow_data::transform::MutableArrayData;
-use arrow_schema::ArrowError;
+use arrow_data::ArrayDataBuilder;
+use arrow_schema::{ArrowError, DataType};
+use std::sync::Arc;
+
+/// Dispatch target handed to `downcast_primitive!`: once the data type has been
+/// matched to a concrete primitive type `$native`, forward the arguments to the
+/// monomorphised `interleave_primitive::<$native>`.
+macro_rules! primitive_helper {
+    ($native:ty, $arrays:ident, $idx:ident, $dt:ident) => {
+        interleave_primitive::<$native>($arrays, $idx, $dt)
+    };
+}
///
/// Takes elements by index from a list of [`Array`], creating a new [`Array`] from those values.
@@ -70,9 +83,51 @@ pub fn interleave(
return Ok(new_empty_array(data_type));
}
- // TODO: Add specialized implementations (#2864)
+ downcast_primitive! {
+ data_type => (primitive_helper, values, indices, data_type),
+ _ => interleave_fallback(values, indices)
+ }
+}
+
+/// Interleaves primitive arrays of type `T` into a new [`PrimitiveArray<T>`].
+///
+/// Output slot `i` takes its value (and validity) from array `indices[i].0`
+/// of `values`, at row `indices[i].1`. Values are copied into one
+/// [`BufferBuilder`]. A validity bitmap is only built when at least one input
+/// array reports a non-zero null count.
+///
+/// Always returns `Ok`; the `Result` mirrors the signature of `interleave`.
+///
+/// # Panics
+///
+/// Panics if an `indices[i].0` is out of bounds for `values`, or if an input is
+/// not a `PrimitiveArray<T>` (presumably ruled out by a data type check in
+/// `interleave` before dispatch — TODO confirm). Out-of-range row indices are
+/// handled by `PrimitiveArray::value` / `is_valid`.
+fn interleave_primitive<T: ArrowPrimitiveType>(
+    values: &[&dyn Array],
+    indices: &[(usize, usize)],
+    data_type: &DataType,
+) -> Result<ArrayRef, ArrowError> {
+    // Decide up front whether a validity bitmap is needed, instead of mutating a
+    // flag from inside a lazy `map` closure.
+    let has_nulls = values.iter().any(|x| x.null_count() != 0);
+    let arrays: Vec<&PrimitiveArray<T>> =
+        values.iter().map(|x| as_primitive_array::<T>(*x)).collect();
+
+    let mut value_buffer = BufferBuilder::<T::Native>::new(indices.len());
+    for &(array_idx, row_idx) in indices {
+        value_buffer.append(arrays[array_idx].value(row_idx));
+    }
+
+    let mut null_count = 0;
+    let nulls = has_nulls.then(|| {
+        let mut builder = BooleanBufferBuilder::new(indices.len());
+        for &(array_idx, row_idx) in indices {
+            let is_valid = arrays[array_idx].is_valid(row_idx);
+            null_count += !is_valid as usize;
+            builder.append(is_valid);
+        }
+        builder.finish()
+    });
+
+    let builder = ArrayDataBuilder::new(data_type.clone())
+        .len(indices.len())
+        .add_buffer(value_buffer.finish())
+        .null_bit_buffer(nulls)
+        .null_count(null_count);
-    interleave_fallback(values, indices)
+    // SAFETY: the builder is consistent by construction:
+    // - `value_buffer` holds exactly `indices.len()` values of `T::Native`,
+    //   matching `len`.
+    // - When present, the validity bitmap holds exactly `indices.len()` bits, and
+    //   `null_count` counts precisely the `false` bits appended to it (it stays 0
+    //   when no bitmap is built).
+    // - `T` was selected by `downcast_primitive!` from `data_type`, so the buffer
+    //   layout matches the declared type.
+    let data = unsafe { builder.build_unchecked() };
+    Ok(Arc::new(PrimitiveArray::<T>::from(data)))
}
/// Fallback implementation of interleave using [`MutableArrayData`]
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index d066cbce9..7a933360c 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -176,6 +176,11 @@ name = "take_kernels"
harness = false
required-features = ["test_utils"]
+[[bench]]
+name = "interleave_kernels"
+harness = false
+required-features = ["test_utils"]
+
[[bench]]
name = "length_kernel"
harness = false
diff --git a/arrow/benches/interleave_kernels.rs b/arrow/benches/interleave_kernels.rs
new file mode 100644
index 000000000..6cf56eb98
--- /dev/null
+++ b/arrow/benches/interleave_kernels.rs
@@ -0,0 +1,91 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+
+use criterion::Criterion;
+use std::ops::Range;
+
+use rand::Rng;
+
+extern crate arrow;
+
+use arrow::datatypes::*;
+use arrow::util::test_util::seedable_rng;
+use arrow::{array::*, util::bench_util::*};
+use arrow_select::interleave::interleave;
+
+/// Registers one `interleave` benchmark named from `prefix`, `len` and `slices`.
+///
+/// `base` is cut into one slice per entry of `slices`. `len` random
+/// `(array, row)` pairs are then drawn over those slices and used as the
+/// interleave indices.
+fn do_bench(
+    c: &mut Criterion,
+    prefix: &str,
+    len: usize,
+    base: &dyn Array,
+    slices: &[Range<usize>],
+) {
+    let mut rng = seedable_rng();
+
+    // One slice of `base` per requested range, borrowed as `&dyn Array` because
+    // that is the input form `interleave` takes.
+    let sliced: Vec<_> = slices
+        .iter()
+        .map(|range| base.slice(range.start, range.end - range.start))
+        .collect();
+    let inputs: Vec<&dyn Array> = sliced.iter().map(|arr| arr.as_ref()).collect();
+
+    // For every output row, pick a source array first and then a row within it.
+    // The RNG comes from `seedable_rng`, so the picks are presumably identical
+    // across runs.
+    let mut indices = Vec::with_capacity(len);
+    for _ in 0..len {
+        let array_idx = rng.gen_range(0..inputs.len());
+        let row_idx = rng.gen_range(0..inputs[array_idx].len());
+        indices.push((array_idx, row_idx));
+    }
+
+    let bench_name = format!("interleave {} {} {:?}", prefix, len, slices);
+    c.bench_function(&bench_name, |b| {
+        b.iter(|| criterion::black_box(interleave(&inputs, &indices).unwrap()))
+    });
+}
+
+/// Registers the primitive `interleave` benchmarks.
+///
+/// Two `Int32` base arrays of length 1024 are used: one built with null
+/// density 0.0 and one with 0.5. Each is interleaved into outputs of length
+/// 100, 400 and 1024 from three slices, plus once at length 1024 from four
+/// slices.
+fn add_benchmark(c: &mut Criterion) {
+    // (label used in the benchmark name, null density passed to the generator)
+    let configs: [(&str, f32); 2] = [("i32(0.0)", 0.), ("i32(0.5)", 0.5)];
+
+    for &(label, null_density) in &configs {
+        let a = create_primitive_array::<Int32Type>(1024, null_density);
+
+        for &len in &[100, 400, 1024] {
+            do_bench(c, label, len, &a, &[0..100, 100..230, 450..1000]);
+        }
+        do_bench(
+            c,
+            label,
+            1024,
+            &a,
+            &[0..100, 100..230, 450..1000, 0..1000],
+        );
+    }
+}
+
+// Benchmark group `benches` running `add_benchmark` under the default
+// Criterion configuration; `criterion_main!` generates the binary's entry point.
+criterion_group! {
+    name = benches;
+    config = Criterion::default();
+    targets = add_benchmark
+}
+criterion_main!(benches);