You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/05/26 20:07:25 UTC

[arrow-rs] branch master updated: Fix filter UB and add fast path (#341)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new e85dc98  Fix filter UB and add fast path (#341)
e85dc98 is described below

commit e85dc984edf2dbd48c7437ca3bed724d2b3ce386
Author: Ritchie Vink <ri...@gmail.com>
AuthorDate: Wed May 26 22:07:19 2021 +0200

    Fix filter UB and add fast path (#341)
    
    * fix ub in filter record_batch
    
    * filter fast path
    
    * add all false fast path
    
    * use new_empty_array
    
    * rename filter kernel argument
    
    rename argument: 'filter' to 'predicate'
    to reduce name collisions.
---
 arrow/src/array/data.rs             |   2 +-
 arrow/src/compute/kernels/filter.rs | 115 ++++++++++++++++++++++++++----------
 2 files changed, 84 insertions(+), 33 deletions(-)

diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index 9d5b0ee..172bdaa 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -412,7 +412,7 @@ impl ArrayData {
     }
 
     /// Returns a new empty [ArrayData] valid for `data_type`.
-    pub(super) fn new_empty(data_type: &DataType) -> Self {
+    pub fn new_empty(data_type: &DataType) -> Self {
         let buffers = new_buffers(data_type, 0);
         let [buffer1, buffer2] = buffers;
         let buffers = into_buffers(data_type, buffer1, buffer2);
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index 4da07b8..b15692e 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -197,14 +197,37 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
     let chunks = iter.collect::<Vec<_>>();
 
     Ok(Box::new(move |array: &ArrayData| {
-        let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
-        chunks
-            .iter()
-            .for_each(|(start, end)| mutable.extend(0, *start, *end));
-        mutable.freeze()
+        match filter_count {
+            // return all
+            len if len == array.len() => array.clone(),
+            0 => ArrayData::new_empty(array.data_type()),
+            _ => {
+                let mut mutable = MutableArrayData::new(vec![array], false, filter_count);
+                chunks
+                    .iter()
+                    .for_each(|(start, end)| mutable.extend(0, *start, *end));
+                mutable.freeze()
+            }
+        }
     }))
 }
 
+/// Remove null values by doing a bitwise AND of the null (validity) bitmap with the boolean value bits, so null slots become `false`.
+fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
+    let array_data = filter.data_ref();
+    let null_bitmap = array_data.null_buffer().unwrap();
+    let mask = filter.values();
+    let offset = filter.offset();
+
+    let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());
+
+    let array_data = ArrayData::builder(DataType::Boolean)
+        .len(filter.len())
+        .add_buffer(new_mask)
+        .build();
+    BooleanArray::from(array_data)
+}
+
 /// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
 ///
 /// # Example
@@ -221,43 +244,49 @@ pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
 /// # Ok(())
 /// # }
 /// ```
-pub fn filter(array: &Array, filter: &BooleanArray) -> Result<ArrayRef> {
-    if filter.null_count() > 0 {
+pub fn filter(array: &Array, predicate: &BooleanArray) -> Result<ArrayRef> {
+    if predicate.null_count() > 0 {
         // this greatly simplifies subsequent filtering code
         // now we only have a boolean mask to deal with
-        let array_data = filter.data_ref();
-        let null_bitmap = array_data.null_buffer().unwrap();
-        let mask = filter.values();
-        let offset = filter.offset();
-
-        let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len());
-
-        let array_data = ArrayData::builder(DataType::Boolean)
-            .len(filter.len())
-            .add_buffer(new_mask)
-            .build();
-        let filter = BooleanArray::from(array_data);
-        // fully qualified syntax, because we have an argument with the same name
-        return crate::compute::kernels::filter::filter(array, &filter);
+        let predicate = prep_null_mask_filter(predicate);
+        return filter(array, &predicate);
     }
 
-    let iter = SlicesIterator::new(filter);
-
-    let mut mutable =
-        MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
-    iter.for_each(|(start, end)| mutable.extend(0, start, end));
-    let data = mutable.freeze();
-    Ok(make_array(data))
+    let iter = SlicesIterator::new(predicate);
+    match iter.filter_count {
+        0 => {
+            // return empty
+            Ok(new_empty_array(array.data_type()))
+        }
+        len if len == array.len() => {
+            // return all
+            let data = array.data().clone();
+            Ok(make_array(data))
+        }
+        _ => {
+            // actually filter
+            let mut mutable =
+                MutableArrayData::new(vec![array.data_ref()], false, iter.filter_count);
+            iter.for_each(|(start, end)| mutable.extend(0, start, end));
+            let data = mutable.freeze();
+            Ok(make_array(data))
+        }
+    }
 }
 
 /// Returns a new [RecordBatch] with arrays containing only values matching the filter.
-/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
-/// Therefore, it is considered undefined behavior to pass `filter` with null values.
 pub fn filter_record_batch(
     record_batch: &RecordBatch,
-    filter: &BooleanArray,
+    predicate: &BooleanArray,
 ) -> Result<RecordBatch> {
-    let filter = build_filter(filter)?;
+    if predicate.null_count() > 0 {
+        // this greatly simplifies subsequent filtering code
+        // now we only have a boolean mask to deal with
+        let predicate = prep_null_mask_filter(predicate);
+        return filter_record_batch(record_batch, &predicate);
+    }
+
+    let filter = build_filter(predicate)?;
     let filtered_arrays = record_batch
         .columns()
         .iter()
@@ -625,4 +654,26 @@ mod tests {
         assert_eq!(out_arr0, out_arr1);
         Ok(())
     }
+
+    #[test]
+    fn test_fast_path() -> Result<()> {
+        let a: PrimitiveArray<Int64Type> =
+            PrimitiveArray::from(vec![Some(1), Some(2), None]);
+
+        // all true
+        let mask = BooleanArray::from(vec![true, true, true]);
+        let out = filter(&a, &mask)?;
+        let b = out
+            .as_any()
+            .downcast_ref::<PrimitiveArray<Int64Type>>()
+            .unwrap();
+        assert_eq!(&a, b);
+
+        // all false
+        let mask = BooleanArray::from(vec![false, false, false]);
+        let out = filter(&a, &mask)?;
+        assert_eq!(out.len(), 0);
+        assert_eq!(out.data_type(), &DataType::Int64);
+        Ok(())
+    }
 }