You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/05/26 20:07:25 UTC
[arrow-rs] branch master updated: Fix filter UB and add fast path
(#341)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new e85dc98 Fix filter UB and add fast path (#341)
e85dc98 is described below
commit e85dc984edf2dbd48c7437ca3bed724d2b3ce386
Author: Ritchie Vink <ri...@gmail.com>
AuthorDate: Wed May 26 22:07:19 2021 +0200
Fix filter UB and add fast path (#341)
* fix UB in filter_record_batch
* filter fast path
* add all false fast path
* use new_empty_array
* rename filter kernel argument
rename argument: 'filter' to 'predicate'
to avoid name collisions with the `filter` function itself.
---
arrow/src/array/data.rs | 2 +-
arrow/src/compute/kernels/filter.rs | 115 ++++++++++++++++++++++++++----------
2 files changed, 84 insertions(+), 33 deletions(-)
diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index 9d5b0ee..172bdaa 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -412,7 +412,7 @@ impl ArrayData {
}
/// Returns a new empty [ArrayData] valid for `data_type`.
- pub(super) fn new_empty(data_type: &DataType) -> Self {
+ pub fn new_empty(data_type: &DataType) -> Self {
let buffers = new_buffers(data_type, 0);
let [buffer1, buffer2] = buffers;
let buffers = into_buffers(data_type, buffer1, buffer2);
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index 4da07b8..b15692e 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -197,14 +197,37 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
let chunks = iter.collect::<Vec<_>>();
Ok(Box::new(move |array: &ArrayData| {
- let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
- chunks
- .iter()
- .for_each(|(start, end)| mutable.extend(0, *start, *end));
- mutable.freeze()
+ match filter_count {
+ // return all
+ len if len == array.len() => array.clone(),
+ 0 => ArrayData::new_empty(array.data_type()),
+ _ => {
+ let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
+ chunks
+ .iter()
+ .for_each(|(start, end)| mutable.extend(0, *start, *end));
+ mutable.freeze()
+ }
+ }
}))
}
+/// Removes null values by doing a bitwise AND of the null (validity) bits and the boolean value bits.
+fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
+ let array_data = filter.data_ref();
+ let null_bitmap = array_data.null_buffer().unwrap();
+ let mask = filter.values();
+ let offset = filter.offset();
+
+ let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());
+
+ let array_data = ArrayData::builder(DataType::Boolean)
+ .len(filter.len())
+ .add_buffer(new_mask)
+ .build();
+ BooleanArray::from(array_data)
+}
+
/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
///
/// # Example
@@ -221,43 +244,49 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
/// # Ok(())
/// # }
/// ```
-pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
- if filter.null_count() > 0 {
+pub fn filter(array: &Array, predicate: &BooleanArray) -> Result<ArrayRef> {
+ if predicate.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
- let array_data = filter.data_ref();
- let null_bitmap = array_data.null_buffer().unwrap();
- let mask = filter.values();
- let offset = filter.offset();
-
- let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());
-
- let array_data = ArrayData::builder(DataType::Boolean)
- .len(filter.len())
- .add_buffer(new_mask)
- .build();
- let filter = BooleanArray::from(array_data);
- // fully qualified syntax, because we have an argument with the same name
- return crate::compute::kernels::filter::filter(array, &filter);
+ let predicate = prep_null_mask_filter(predicate);
+ return filter(array, &predicate);
}
- let iter = SlicesIterator::new(filter);
-
- let mut mutable =
- MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
- iter.for_each(|(start, end)| mutable.extend(0, start, end));
- let data = mutable.freeze();
- Ok(make_array(data))
+ let iter = SlicesIterator::new(predicate);
+ match iter.filter_count {
+ 0 => {
+ // return empty
+ Ok(new_empty_array(array.data_type()))
+ }
+ len if len == array.len() => {
+ // return all
+ let data = array.data().clone();
+ Ok(make_array(data))
+ }
+ _ => {
+ // actually filter
+ let mut mutable =
+ MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
+ iter.for_each(|(start, end)| mutable.extend(0, start, end));
+ let data = mutable.freeze();
+ Ok(make_array(data))
+ }
+ }
}
/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
-/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
-/// Therefore, it is considered undefined behavior to pass `filter` with null values.
pub fn filter_record_batch(
record_batch: &RecordBatch,
- filter: &BooleanArray,
+ predicate: &BooleanArray,
) -> Result<RecordBatch> {
- let filter = build_filter(filter)?;
+ if predicate.null_count() > 0 {
+ // this greatly simplifies subsequent filtering code
+ // now we only have a boolean mask to deal with
+ let predicate = prep_null_mask_filter(predicate);
+ return filter_record_batch(record_batch, &predicate);
+ }
+
+ let filter = build_filter(predicate)?;
let filtered_arrays = record_batch
.columns()
.iter()
@@ -625,4 +654,26 @@ mod tests {
assert_eq!(out_arr0, out_arr1);
Ok(())
}
+
+ #[test]
+ fn test_fast_path() -> Result<()> {
+ let a: PrimitiveArray<Int64Type> =
+ PrimitiveArray::from(vec![Some(1), Some(2), None]);
+
+ // all true
+ let mask = BooleanArray::from(vec![true, true, true]);
+ let out = filter(&a, &mask)?;
+ let b = out
+ .as_any()
+ .downcast_ref::<PrimitiveArray<Int64Type>>()
+ .unwrap();
+ assert_eq!(&a, b);
+
+ // all false
+ let mask = BooleanArray::from(vec![false, false, false]);
+ let out = filter(&a, &mask)?;
+ assert_eq!(out.len(), 0);
+ assert_eq!(out.data_type(), &DataType::Int64);
+ Ok(())
+ }
}