You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/02/09 21:03:23 UTC

[arrow-rs] branch master updated: Specialized filter kernels (#1248)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c064d53  Specialized filter kernels (#1248)
c064d53 is described below

commit c064d53340bd6811b434311434f668a065017d80
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed Feb 9 21:03:15 2022 +0000

    Specialized filter kernels (#1248)
    
    * Add specialized primitive filter kernels
    
    * Filter context
    
    * Optimize null buffer construction
    
    * Clippy
    
    * Benchmark filter construction
    
    * Review feedback
    
    * Specialized string filter
    
    * Specialized dictionary filter kernel
    
    * Use trusted_len_iter
    
    * Review feedback
    
    * Add fuzz filter test
    
    * Clarify selective vs selectivity confusion
    
    * Revert change to MutableBuffer::from_trusted_len_iter_bool
    
    * Fix filter_bits offset handling
    
    * Review feedback
    
    * Use i64 for chunk offset
    
    * Only optimize filter when filtering multiple columns
    
    * Test truncated filter
    
    * Review feedback
    
    * Add IterationStrategy::None
    
    * Remove selective / selectivity docs confusion
---
 arrow/benches/filter_kernels.rs     | 139 ++++--
 arrow/src/compute/kernels/filter.rs | 900 +++++++++++++++++++++++++++++++++---
 arrow/src/util/bench_util.rs        |  23 +
 3 files changed, 963 insertions(+), 99 deletions(-)

diff --git a/arrow/benches/filter_kernels.rs b/arrow/benches/filter_kernels.rs
index d5ff09c..be6d902 100644
--- a/arrow/benches/filter_kernels.rs
+++ b/arrow/benches/filter_kernels.rs
@@ -18,13 +18,13 @@ extern crate arrow;
 
 use std::sync::Arc;
 
-use arrow::compute::{filter_record_batch, Filter};
+use arrow::compute::{filter_record_batch, FilterBuilder, FilterPredicate};
 use arrow::record_batch::RecordBatch;
 use arrow::util::bench_util::*;
 
 use arrow::array::*;
-use arrow::compute::{build_filter, filter};
-use arrow::datatypes::{Field, Float32Type, Schema, UInt8Type};
+use arrow::compute::filter;
+use arrow::datatypes::{Field, Float32Type, Int32Type, Schema, UInt8Type};
 
 use criterion::{criterion_group, criterion_main, Criterion};
 
@@ -32,8 +32,8 @@ fn bench_filter(data_array: &dyn Array, filter_array: &BooleanArray) {
     criterion::black_box(filter(data_array, filter_array).unwrap());
 }
 
-fn bench_built_filter<'a>(filter: &Filter<'a>, data: &impl Array) {
-    criterion::black_box(filter(data.data()));
+fn bench_built_filter(filter: &FilterPredicate, array: &dyn Array) {
+    criterion::black_box(filter.filter(array).unwrap());
 }
 
 fn add_benchmark(c: &mut Criterion) {
@@ -42,68 +42,145 @@ fn add_benchmark(c: &mut Criterion) {
     let dense_filter_array = create_boolean_array(size, 0.0, 1.0 - 1.0 / 1024.0);
     let sparse_filter_array = create_boolean_array(size, 0.0, 1.0 / 1024.0);
 
-    let filter = build_filter(&filter_array).unwrap();
-    let dense_filter = build_filter(&dense_filter_array).unwrap();
-    let sparse_filter = build_filter(&sparse_filter_array).unwrap();
+    let filter = FilterBuilder::new(&filter_array).optimize().build();
+    let dense_filter = FilterBuilder::new(&dense_filter_array).optimize().build();
+    let sparse_filter = FilterBuilder::new(&sparse_filter_array).optimize().build();
 
     let data_array = create_primitive_array::<UInt8Type>(size, 0.0);
 
-    c.bench_function("filter u8", |b| {
+    c.bench_function("filter optimize (kept 1/2)", |b| {
+        b.iter(|| FilterBuilder::new(&filter_array).optimize().build())
+    });
+
+    c.bench_function("filter optimize high selectivity (kept 1023/1024)", |b| {
+        b.iter(|| FilterBuilder::new(&dense_filter_array).optimize().build())
+    });
+
+    c.bench_function("filter optimize low selectivity (kept 1/1024)", |b| {
+        b.iter(|| FilterBuilder::new(&sparse_filter_array).optimize().build())
+    });
+
+    c.bench_function("filter u8 (kept 1/2)", |b| {
         b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter u8 high selectivity", |b| {
+    c.bench_function("filter u8 high selectivity (kept 1023/1024)", |b| {
         b.iter(|| bench_filter(&data_array, &dense_filter_array))
     });
-    c.bench_function("filter u8 low selectivity", |b| {
+    c.bench_function("filter u8 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_filter(&data_array, &sparse_filter_array))
     });
 
-    c.bench_function("filter context u8", |b| {
+    c.bench_function("filter context u8 (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context u8 high selectivity", |b| {
+    c.bench_function("filter context u8 high selectivity (kept 1023/1024)", |b| {
         b.iter(|| bench_built_filter(&dense_filter, &data_array))
     });
-    c.bench_function("filter context u8 low selectivity", |b| {
+    c.bench_function("filter context u8 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
-    let data_array = create_primitive_array::<UInt8Type>(size, 0.5);
-    c.bench_function("filter context u8 w NULLs", |b| {
-        b.iter(|| bench_built_filter(&filter, &data_array))
+    let data_array = create_primitive_array::<Int32Type>(size, 0.0);
+    c.bench_function("filter i32 (kept 1/2)", |b| {
+        b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter context u8 w NULLs high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
+    c.bench_function("filter i32 high selectivity (kept 1023/1024)", |b| {
+        b.iter(|| bench_filter(&data_array, &dense_filter_array))
+    });
+    c.bench_function("filter i32 low selectivity (kept 1/1024)", |b| {
+        b.iter(|| bench_filter(&data_array, &sparse_filter_array))
+    });
+
+    c.bench_function("filter context i32 (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context u8 w NULLs low selectivity", |b| {
+    c.bench_function(
+        "filter context i32 high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context i32 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
+    let data_array = create_primitive_array::<Int32Type>(size, 0.5);
+    c.bench_function("filter context i32 w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context i32 w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context i32 w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
+    let data_array = create_primitive_array::<UInt8Type>(size, 0.5);
+    c.bench_function("filter context u8 w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context u8 w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context u8 w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
     let data_array = create_primitive_array::<Float32Type>(size, 0.5);
-    c.bench_function("filter f32", |b| {
+    c.bench_function("filter f32 (kept 1/2)", |b| {
         b.iter(|| bench_filter(&data_array, &filter_array))
     });
-    c.bench_function("filter context f32", |b| {
+    c.bench_function("filter context f32 (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context f32 high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
-    });
-    c.bench_function("filter context f32 low selectivity", |b| {
+    c.bench_function(
+        "filter context f32 high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context f32 low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
     let data_array = create_string_array::<i32>(size, 0.5);
-    c.bench_function("filter context string", |b| {
+    c.bench_function("filter context string (kept 1/2)", |b| {
         b.iter(|| bench_built_filter(&filter, &data_array))
     });
-    c.bench_function("filter context string high selectivity", |b| {
-        b.iter(|| bench_built_filter(&dense_filter, &data_array))
-    });
-    c.bench_function("filter context string low selectivity", |b| {
+    c.bench_function(
+        "filter context string high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function("filter context string low selectivity (kept 1/1024)", |b| {
         b.iter(|| bench_built_filter(&sparse_filter, &data_array))
     });
 
+    let data_array = create_string_dict_array::<Int32Type>(size, 0.0);
+    c.bench_function("filter context string dictionary (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context string dictionary high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context string dictionary low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
+    let data_array = create_string_dict_array::<Int32Type>(size, 0.5);
+    c.bench_function("filter context string dictionary w NULLs (kept 1/2)", |b| {
+        b.iter(|| bench_built_filter(&filter, &data_array))
+    });
+    c.bench_function(
+        "filter context string dictionary w NULLs high selectivity (kept 1023/1024)",
+        |b| b.iter(|| bench_built_filter(&dense_filter, &data_array)),
+    );
+    c.bench_function(
+        "filter context string dictionary w NULLs low selectivity (kept 1/1024)",
+        |b| b.iter(|| bench_built_filter(&sparse_filter, &data_array)),
+    );
+
     let data_array = create_primitive_array::<Float32Type>(size, 0.0);
 
     let field = Field::new("c1", data_array.data_type().clone(), true);
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index 0418263..e90add4 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -17,24 +17,60 @@
 
 //! Defines miscellaneous array kernels.
 
+use std::ops::AddAssign;
+use std::sync::Arc;
+
+use num::Zero;
+
+use TimeUnit::*;
+
 use crate::array::*;
-use crate::buffer::buffer_bin_and;
-use crate::datatypes::DataType;
-use crate::error::Result;
+use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer};
+use crate::datatypes::*;
+use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator};
+use crate::util::bit_util;
 
-/// Function that can filter arbitrary arrays
-pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
+/// If the filter selects more than this fraction of rows, use
+/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
+/// over individual rows using [`IndexIterator`]
+///
+/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
+///
+const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
+
+macro_rules! downcast_filter {
+    ($type: ty, $values: expr, $filter: expr) => {{
+        let values = $values
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$type>>()
+            .expect("Unable to downcast to a primitive array");
+
+        Ok(Arc::new(filter_primitive::<$type>(&values, $filter)))
+    }};
+}
+
+macro_rules! downcast_dict_filter {
+    ($type: ty, $values: expr, $filter: expr) => {{
+        let values = $values
+            .as_any()
+            .downcast_ref::<DictionaryArray<$type>>()
+            .expect("Unable to downcast to a dictionary array");
+        Ok(Arc::new(filter_dict::<$type>(values, $filter)))
+    }};
+}
 
-/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose
-/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be
-/// "taken" from an array to be filtered.
+/// An iterator of `(usize, usize)` each representing an interval `[start, end)` whose
+/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory
+/// to be "taken" from an array to be filtered.
+///
+/// This is only performant for filters that copy across long contiguous runs
 #[derive(Debug)]
 pub struct SlicesIterator<'a> {
     iter: UnalignedBitChunkIterator<'a>,
     len: usize,
-    chunk_end_offset: usize,
+    current_offset: i64,
     current_chunk: u64,
 }
 
@@ -45,13 +81,13 @@ impl<'a> SlicesIterator<'a> {
         let chunk = UnalignedBitChunk::new(values.as_slice(), filter.offset(), len);
         let mut iter = chunk.iter();
 
-        let chunk_end_offset = 64 - chunk.lead_padding();
+        let current_offset = -(chunk.lead_padding() as i64);
         let current_chunk = iter.next().unwrap_or(0);
 
         Self {
             iter,
             len,
-            chunk_end_offset,
+            current_offset,
             current_chunk,
         }
     }
@@ -59,18 +95,18 @@ impl<'a> SlicesIterator<'a> {
     /// Returns `Some((chunk_offset, bit_offset))` for the next chunk that has at
     /// least one bit set, or None if there is no such chunk.
     ///
-    /// Where `chunk_offset` is the bit offset to the current `usize`d chunk
+    /// Where `chunk_offset` is the bit offset to the current `u64` chunk
     /// and `bit_offset` is the offset of the first `1` bit in that chunk
-    fn advance_to_set_bit(&mut self) -> Option<(usize, u32)> {
+    fn advance_to_set_bit(&mut self) -> Option<(i64, u32)> {
         loop {
             if self.current_chunk != 0 {
                 // Find the index of the first 1
                 let bit_pos = self.current_chunk.trailing_zeros();
-                return Some((self.chunk_end_offset, bit_pos));
+                return Some((self.current_offset, bit_pos));
             }
 
             self.current_chunk = self.iter.next()?;
-            self.chunk_end_offset += 64;
+            self.current_offset += 64;
         }
     }
 }
@@ -98,19 +134,19 @@ impl<'a> Iterator for SlicesIterator<'a> {
                 self.current_chunk &= !((1 << end_bit) - 1);
 
                 return Some((
-                    start_chunk + start_bit as usize - 64,
-                    self.chunk_end_offset + end_bit as usize - 64,
+                    (start_chunk + start_bit as i64) as usize,
+                    (self.current_offset + end_bit as i64) as usize,
                 ));
             }
 
             match self.iter.next() {
                 Some(next) => {
                     self.current_chunk = next;
-                    self.chunk_end_offset += 64;
+                    self.current_offset += 64;
                 }
                 None => {
                     return Some((
-                        start_chunk + start_bit as usize - 64,
+                        (start_chunk + start_bit as i64) as usize,
                         std::mem::replace(&mut self.len, 0),
                     ));
                 }
@@ -119,17 +155,83 @@ impl<'a> Iterator for SlicesIterator<'a> {
     }
 }
 
+/// An iterator over the `usize` indices of the bits that are set in a [`BooleanArray`]
+///
+/// This provides the best performance on most predicates, apart from those which keep
+/// large runs and therefore favour [`SlicesIterator`]
+struct IndexIterator<'a> {
+    current_chunk: u64,
+    chunk_offset: i64,
+    remaining: usize,
+    iter: UnalignedBitChunkIterator<'a>,
+}
+
+impl<'a> IndexIterator<'a> {
+    fn new(filter: &'a BooleanArray, len: usize) -> Self {
+        assert_eq!(filter.null_count(), 0);
+        let data = filter.data();
+        let chunks =
+            UnalignedBitChunk::new(&data.buffers()[0], data.offset(), data.len());
+        let mut iter = chunks.iter();
+
+        let current_chunk = iter.next().unwrap_or(0);
+        let chunk_offset = -(chunks.lead_padding() as i64);
+
+        Self {
+            current_chunk,
+            chunk_offset,
+            remaining: len,
+            iter,
+        }
+    }
+}
+
+impl<'a> Iterator for IndexIterator<'a> {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        while self.remaining != 0 {
+            if self.current_chunk != 0 {
+                let bit_pos = self.current_chunk.trailing_zeros();
+                self.current_chunk ^= 1 << bit_pos;
+                self.remaining -= 1;
+                return Some((self.chunk_offset + bit_pos as i64) as usize);
+            }
+
+            // Must panic if exhausted early, as this is relied on as a trusted-length iterator
+            self.current_chunk = self.iter.next().expect("IndexIterator exhausted early");
+            self.chunk_offset += 64;
+        }
+        None
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.remaining, Some(self.remaining))
+    }
+}
+
+/// Counts the number of set bits in `filter`
 fn filter_count(filter: &BooleanArray) -> usize {
     filter
         .values()
         .count_set_bits_offset(filter.offset(), filter.len())
 }
 
+/// Function that can filter arbitrary arrays
+///
+/// Deprecated: Use [`FilterPredicate`] instead
+#[deprecated]
+pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;
+
 /// Returns a prepared function optimized to filter multiple arrays.
 /// Creating this function requires time, but using it is faster than [filter] when the
 /// same filter needs to be applied to multiple arrays (e.g. a multi-column `RecordBatch`).
 /// WARNING: the nulls of `filter` are ignored and the value on its slot is considered.
 /// Therefore, it is considered undefined behavior to pass `filter` with null values.
+///
+/// Deprecated: Use [`FilterBuilder`] instead
+#[deprecated]
+#[allow(deprecated)]
 pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
     let iter = SlicesIterator::new(filter);
     let filter_count = filter_count(filter);
@@ -185,79 +287,600 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
 /// # Ok(())
 /// # }
 /// ```
-pub fn filter(array: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
-    if predicate.null_count() > 0 {
-        // this greatly simplifies subsequent filtering code
-        // now we only have a boolean mask to deal with
-        let predicate = prep_null_mask_filter(predicate);
-        return filter(array, &predicate);
+pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
+    let predicate = FilterBuilder::new(predicate).build();
+    filter_array(values, &predicate)
+}
+
+/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
+pub fn filter_record_batch(
+    record_batch: &RecordBatch,
+    predicate: &BooleanArray,
+) -> Result<RecordBatch> {
+    let mut filter_builder = FilterBuilder::new(predicate);
+    if record_batch.num_columns() > 1 {
+        // Only optimize if filtering more than one column
+        filter_builder = filter_builder.optimize();
+    }
+    let filter = filter_builder.build();
+
+    let filtered_arrays = record_batch
+        .columns()
+        .iter()
+        .map(|a| filter_array(a, &filter))
+        .collect::<Result<Vec<_>>>()?;
+
+    RecordBatch::try_new(record_batch.schema(), filtered_arrays)
+}
+
+/// A builder to construct [`FilterPredicate`]
+#[derive(Debug)]
+pub struct FilterBuilder {
+    filter: BooleanArray,
+    count: usize,
+    strategy: IterationStrategy,
+}
+
+impl FilterBuilder {
+    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
+    pub fn new(filter: &BooleanArray) -> Self {
+        let filter = match filter.null_count() {
+            0 => BooleanArray::from(filter.data().clone()),
+            _ => prep_null_mask_filter(filter),
+        };
+
+        let count = filter_count(&filter);
+        let strategy = IterationStrategy::default_strategy(filter.len(), count);
+
+        Self {
+            filter,
+            count,
+            strategy,
+        }
     }
 
-    let filter_count = filter_count(predicate);
+    /// Compute an optimised representation of the provided `filter` mask that can be
+    /// applied to an array more quickly.
+    ///
+    /// Note: There is limited benefit to calling this to then filter a single array
+    /// Note: This will likely have a larger memory footprint than the original mask
+    pub fn optimize(mut self) -> Self {
+        match self.strategy {
+            IterationStrategy::SlicesIterator => {
+                let slices = SlicesIterator::new(&self.filter).collect();
+                self.strategy = IterationStrategy::Slices(slices)
+            }
+            IterationStrategy::IndexIterator => {
+                let indices = IndexIterator::new(&self.filter, self.count).collect();
+                self.strategy = IterationStrategy::Indices(indices)
+            }
+            _ => {}
+        }
+        self
+    }
 
-    match filter_count {
-        0 => {
-            // return empty
-            Ok(new_empty_array(array.data_type()))
+    /// Construct the final `FilterPredicate`
+    pub fn build(self) -> FilterPredicate {
+        FilterPredicate {
+            filter: self.filter,
+            count: self.count,
+            strategy: self.strategy,
         }
-        len if len == array.len() => {
-            // return all
-            let data = array.data().clone();
-            Ok(make_array(data))
+    }
+}
+
+/// The iteration strategy used to evaluate [`FilterPredicate`]
+#[derive(Debug)]
+enum IterationStrategy {
+    /// A lazily evaluated iterator of ranges
+    SlicesIterator,
+    /// A lazily evaluated iterator of indices
+    IndexIterator,
+    /// A precomputed list of indices
+    Indices(Vec<usize>),
+    /// A precomputed array of ranges
+    Slices(Vec<(usize, usize)>),
+    /// Select all rows
+    All,
+    /// Select no rows
+    None,
+}
+
+impl IterationStrategy {
+    /// The default [`IterationStrategy`] for a filter of length `filter_length`
+    /// and selecting `filter_count` rows
+    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
+        if filter_length == 0 || filter_count == 0 {
+            return IterationStrategy::None;
         }
-        _ => {
-            // actually filter
-            let mut mutable =
-                MutableArrayData::new(vec![array.data_ref()], false, filter_count);
 
-            let iter = SlicesIterator::new(predicate);
-            iter.for_each(|(start, end)| mutable.extend(0, start, end));
+        if filter_count == filter_length {
+            return IterationStrategy::All;
+        }
 
-            let data = mutable.freeze();
-            Ok(make_array(data))
+        // Compute the selectivity of the predicate by dividing the number of true
+        // bits in the predicate by the predicate's total length
+        //
+        // This can then be used as a heuristic for the optimal iteration strategy
+        let selectivity_frac = filter_count as f64 / filter_length as f64;
+        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
+            return IterationStrategy::SlicesIterator;
         }
+        IterationStrategy::IndexIterator
     }
 }
 
-/// Returns a new [RecordBatch] with arrays containing only values matching the filter.
-pub fn filter_record_batch(
-    record_batch: &RecordBatch,
-    predicate: &BooleanArray,
-) -> Result<RecordBatch> {
-    if predicate.null_count() > 0 {
-        // this greatly simplifies subsequent filtering code
-        // now we only have a boolean mask to deal with
-        let predicate = prep_null_mask_filter(predicate);
-        return filter_record_batch(record_batch, &predicate);
+/// A filtering predicate that can be applied to an [`Array`]
+#[derive(Debug)]
+pub struct FilterPredicate {
+    filter: BooleanArray,
+    count: usize,
+    strategy: IterationStrategy,
+}
+
+impl FilterPredicate {
+    /// Selects rows from `values` based on this [`FilterPredicate`]
+    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef> {
+        filter_array(values, self)
     }
+}
 
-    let num_columns = record_batch.columns().len();
+fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef> {
+    if predicate.filter.len() > values.len() {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "Filter predicate of length {} is larger than target array of length {}",
+            predicate.filter.len(),
+            values.len()
+        )));
+    }
 
-    let filtered_arrays = match num_columns {
-        1 => {
-            vec![filter(record_batch.columns()[0].as_ref(), predicate)?]
+    match predicate.strategy {
+        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
+        IterationStrategy::All => Ok(make_array(values.data().slice(0, predicate.count))),
+        // actually filter
+        _ => match values.data_type() {
+            DataType::Boolean => {
+                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
+                Ok(Arc::new(filter_boolean(values, predicate)))
+            }
+            DataType::Int8 => {
+                downcast_filter!(Int8Type, values, predicate)
+            }
+            DataType::Int16 => {
+                downcast_filter!(Int16Type, values, predicate)
+            }
+            DataType::Int32 => {
+                downcast_filter!(Int32Type, values, predicate)
+            }
+            DataType::Int64 => {
+                downcast_filter!(Int64Type, values, predicate)
+            }
+            DataType::UInt8 => {
+                downcast_filter!(UInt8Type, values, predicate)
+            }
+            DataType::UInt16 => {
+                downcast_filter!(UInt16Type, values, predicate)
+            }
+            DataType::UInt32 => {
+                downcast_filter!(UInt32Type, values, predicate)
+            }
+            DataType::UInt64 => {
+                downcast_filter!(UInt64Type, values, predicate)
+            }
+            DataType::Float32 => {
+                downcast_filter!(Float32Type, values, predicate)
+            }
+            DataType::Float64 => {
+                downcast_filter!(Float64Type, values, predicate)
+            }
+            DataType::Date32 => {
+                downcast_filter!(Date32Type, values, predicate)
+            }
+            DataType::Date64 => {
+                downcast_filter!(Date64Type, values, predicate)
+            }
+            DataType::Time32(Second) => {
+                downcast_filter!(Time32SecondType, values, predicate)
+            }
+            DataType::Time32(Millisecond) => {
+                downcast_filter!(Time32MillisecondType, values, predicate)
+            }
+            DataType::Time64(Microsecond) => {
+                downcast_filter!(Time64MicrosecondType, values, predicate)
+            }
+            DataType::Time64(Nanosecond) => {
+                downcast_filter!(Time64NanosecondType, values, predicate)
+            }
+            DataType::Timestamp(Second, _) => {
+                downcast_filter!(TimestampSecondType, values, predicate)
+            }
+            DataType::Timestamp(Millisecond, _) => {
+                downcast_filter!(TimestampMillisecondType, values, predicate)
+            }
+            DataType::Timestamp(Microsecond, _) => {
+                downcast_filter!(TimestampMicrosecondType, values, predicate)
+            }
+            DataType::Timestamp(Nanosecond, _) => {
+                downcast_filter!(TimestampNanosecondType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                downcast_filter!(IntervalYearMonthType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                downcast_filter!(IntervalDayTimeType, values, predicate)
+            }
+            DataType::Interval(IntervalUnit::MonthDayNano) => {
+                downcast_filter!(IntervalMonthDayNanoType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Second) => {
+                downcast_filter!(DurationSecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Millisecond) => {
+                downcast_filter!(DurationMillisecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Microsecond) => {
+                downcast_filter!(DurationMicrosecondType, values, predicate)
+            }
+            DataType::Duration(TimeUnit::Nanosecond) => {
+                downcast_filter!(DurationNanosecondType, values, predicate)
+            }
+            DataType::Utf8 => {
+                let values = values
+                    .as_any()
+                    .downcast_ref::<GenericStringArray<i32>>()
+                    .unwrap();
+                Ok(Arc::new(filter_string::<i32>(values, predicate)))
+            }
+            DataType::LargeUtf8 => {
+                let values = values
+                    .as_any()
+                    .downcast_ref::<GenericStringArray<i64>>()
+                    .unwrap();
+                Ok(Arc::new(filter_string::<i64>(values, predicate)))
+            }
+            DataType::Dictionary(key_type, _) => match key_type.as_ref() {
+                DataType::Int8 => downcast_dict_filter!(Int8Type, values, predicate),
+                DataType::Int16 => downcast_dict_filter!(Int16Type, values, predicate),
+                DataType::Int32 => downcast_dict_filter!(Int32Type, values, predicate),
+                DataType::Int64 => downcast_dict_filter!(Int64Type, values, predicate),
+                DataType::UInt8 => downcast_dict_filter!(UInt8Type, values, predicate),
+                DataType::UInt16 => downcast_dict_filter!(UInt16Type, values, predicate),
+                DataType::UInt32 => downcast_dict_filter!(UInt32Type, values, predicate),
+                DataType::UInt64 => downcast_dict_filter!(UInt64Type, values, predicate),
+                t => {
+                    unimplemented!("Filter not supported for dictionary key type {:?}", t)
+                }
+            },
+            _ => {
+                // fallback to using MutableArrayData
+                let mut mutable = MutableArrayData::new(
+                    vec![values.data_ref()],
+                    false,
+                    predicate.count,
+                );
+
+                match &predicate.strategy {
+                    IterationStrategy::Slices(slices) => {
+                        slices
+                            .iter()
+                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
+                    }
+                    _ => {
+                        let iter = SlicesIterator::new(&predicate.filter);
+                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
+                    }
+                }
+
+                let data = mutable.freeze();
+                Ok(make_array(data))
+            }
+        },
+    }
+}
+
+/// Computes a new null mask for `data` based on `predicate`
+///
+/// Returns `None` if `data` contains no nulls or if `predicate` selected no null
+/// rows. Otherwise returns `Some((null_count, null_buffer))` where `null_count`
+/// is the number of nulls in the filtered output, and `null_buffer` is the
+/// filtered null buffer
+fn filter_null_mask(
+    data: &ArrayData,
+    predicate: &FilterPredicate,
+) -> Option<(usize, Buffer)> {
+    if data.null_count() == 0 {
+        return None;
+    }
+
+    let nulls = filter_bits(data.null_buffer()?, data.offset(), predicate);
+    // The filtered `nulls` has a length of `predicate.count` bits and
+    // therefore the null count is this minus the number of valid bits
+    let null_count = predicate.count - nulls.count_set_bits();
+
+    if null_count == 0 {
+        return None;
+    }
+
+    Some((null_count, nulls))
+}
+
+/// Filter the packed bitmask `buffer`, which starts at bit offset `offset`, by `predicate`
+fn filter_bits(buffer: &Buffer, offset: usize, predicate: &FilterPredicate) -> Buffer {
+    let src = buffer.as_slice();
+
+    match &predicate.strategy {
+        IterationStrategy::IndexIterator => {
+            let bits = IndexIterator::new(&predicate.filter, predicate.count)
+                .map(|src_idx| bit_util::get_bit(src, src_idx + offset));
+
+            // SAFETY: `IndexIterator` reports its size correctly
+            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
         }
-        _ => {
-            let filter = build_filter(predicate)?;
-            record_batch
-                .columns()
+        IterationStrategy::Indices(indices) => {
+            let bits = indices
                 .iter()
-                .map(|a| make_array(filter(a.data())))
-                .collect()
+                .map(|src_idx| bit_util::get_bit(src, *src_idx + offset));
+
+            // SAFETY: `Vec::iter()` reports its size correctly
+            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
+        }
+        IterationStrategy::SlicesIterator => {
+            let mut builder =
+                BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8));
+            for (start, end) in SlicesIterator::new(&predicate.filter) {
+                builder.append_packed_range(start + offset..end + offset, src)
+            }
+            builder.finish()
+        }
+        IterationStrategy::Slices(slices) => {
+            let mut builder =
+                BooleanBufferBuilder::new(bit_util::ceil(predicate.count, 8));
+            for (start, end) in slices {
+                builder.append_packed_range(*start + offset..*end + offset, src)
+            }
+            builder.finish()
+        }
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
+    }
+}
+
+/// `filter` implementation for boolean arrays
+fn filter_boolean(values: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
+    let data = values.data();
+    assert_eq!(data.buffers().len(), 1);
+    assert_eq!(data.child_data().len(), 0);
+
+    let values = filter_bits(&data.buffers()[0], data.offset(), predicate);
+
+    let mut builder = ArrayDataBuilder::new(DataType::Boolean)
+        .len(predicate.count)
+        .add_buffer(values);
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    BooleanArray::from(data)
+}
+
+/// `filter` implementation for primitive arrays
+fn filter_primitive<T>(
+    values: &PrimitiveArray<T>,
+    predicate: &FilterPredicate,
+) -> PrimitiveArray<T>
+where
+    T: ArrowPrimitiveType,
+{
+    let data = values.data();
+    assert_eq!(data.buffers().len(), 1);
+    assert_eq!(data.child_data().len(), 0);
+
+    let values = data.buffer::<T::Native>(0);
+    assert!(values.len() >= predicate.filter.len());
+
+    let buffer = match &predicate.strategy {
+        IterationStrategy::SlicesIterator => {
+            let mut buffer =
+                MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
+            for (start, end) in SlicesIterator::new(&predicate.filter) {
+                buffer.extend_from_slice(&values[start..end]);
+            }
+            buffer
+        }
+        IterationStrategy::Slices(slices) => {
+            let mut buffer =
+                MutableBuffer::with_capacity(predicate.count * T::get_byte_width());
+            for (start, end) in slices {
+                buffer.extend_from_slice(&values[*start..*end]);
+            }
+            buffer
+        }
+        IterationStrategy::IndexIterator => {
+            let iter =
+                IndexIterator::new(&predicate.filter, predicate.count).map(|x| values[x]);
+
+            // SAFETY: IndexIterator is trusted length
+            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
+        }
+        IterationStrategy::Indices(indices) => {
+            let iter = indices.iter().map(|x| values[*x]);
+
+            // SAFETY: `Vec::iter` is trusted length
+            unsafe { MutableBuffer::from_trusted_len_iter(iter) }
         }
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
     };
-    RecordBatch::try_new(record_batch.schema(), filtered_arrays)
+
+    let mut builder = ArrayDataBuilder::new(data.data_type().clone())
+        .len(predicate.count)
+        .add_buffer(buffer.into());
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    PrimitiveArray::from(data)
+}
+
+/// [`FilterString`] is created from a source [`GenericStringArray`] and can be
+/// used to build a new [`GenericStringArray`] by copying values from the source
+///
+/// TODO(raphael): Could this be used for the take kernel as well?
+struct FilterString<'a, OffsetSize> {
+    src_offsets: &'a [OffsetSize],
+    src_values: &'a [u8],
+    dst_offsets: MutableBuffer,
+    dst_values: MutableBuffer,
+    cur_offset: OffsetSize,
+}
+
+impl<'a, OffsetSize> FilterString<'a, OffsetSize>
+where
+    OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
+{
+    fn new(capacity: usize, array: &'a GenericStringArray<OffsetSize>) -> Self {
+        let num_offsets_bytes = (capacity + 1) * std::mem::size_of::<OffsetSize>();
+        let mut dst_offsets = MutableBuffer::new(num_offsets_bytes);
+        let dst_values = MutableBuffer::new(0);
+        let cur_offset = OffsetSize::zero();
+        dst_offsets.push(cur_offset);
+
+        Self {
+            src_offsets: array.value_offsets(),
+            src_values: &array.data().buffers()[1],
+            dst_offsets,
+            dst_values,
+            cur_offset,
+        }
+    }
+
+    /// Returns the byte offset at `idx`
+    #[inline]
+    fn get_value_offset(&self, idx: usize) -> usize {
+        self.src_offsets[idx].to_usize().expect("illegal offset")
+    }
+
+    /// Returns the start and end of the value at index `idx` along with its length
+    #[inline]
+    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
+        // These can only fail if `array` contains invalid data
+        let start = self.get_value_offset(idx);
+        let end = self.get_value_offset(idx + 1);
+        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
+        (start, end, len)
+    }
+
+    /// Extends the in-progress array by the indexes in the provided iterator
+    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
+        for idx in iter {
+            let (start, end, len) = self.get_value_range(idx);
+            self.cur_offset += len;
+            self.dst_offsets.push(self.cur_offset);
+            self.dst_values
+                .extend_from_slice(&self.src_values[start..end]);
+        }
+    }
+
+    /// Extends the in-progress array by the ranges in the provided iterator
+    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
+        for (start, end) in iter {
+            // These can only fail if `array` contains invalid data
+            for idx in start..end {
+                let (_, _, len) = self.get_value_range(idx);
+                self.cur_offset += len;
+                self.dst_offsets.push(self.cur_offset); // push_unchecked?
+            }
+
+            let value_start = self.get_value_offset(start);
+            let value_end = self.get_value_offset(end);
+            self.dst_values
+                .extend_from_slice(&self.src_values[value_start..value_end]);
+        }
+    }
+}
+
+/// `filter` implementation for string arrays
+///
+/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
+/// data copied across. This allows handling the null mask separately from the data
+fn filter_string<OffsetSize>(
+    array: &GenericStringArray<OffsetSize>,
+    predicate: &FilterPredicate,
+) -> GenericStringArray<OffsetSize>
+where
+    OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
+{
+    let data = array.data();
+    assert_eq!(data.buffers().len(), 2);
+    assert_eq!(data.child_data().len(), 0);
+    let mut filter = FilterString::new(predicate.count, array);
+
+    match &predicate.strategy {
+        IterationStrategy::SlicesIterator => {
+            filter.extend_slices(SlicesIterator::new(&predicate.filter))
+        }
+        IterationStrategy::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
+        IterationStrategy::IndexIterator => {
+            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
+        }
+        IterationStrategy::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
+        IterationStrategy::All | IterationStrategy::None => unreachable!(),
+    }
+
+    let mut builder = ArrayDataBuilder::new(data.data_type().clone())
+        .len(predicate.count)
+        .add_buffer(filter.dst_offsets.into())
+        .add_buffer(filter.dst_values.into());
+
+    if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
+        builder = builder.null_count(null_count).null_bit_buffer(nulls);
+    }
+
+    let data = unsafe { builder.build_unchecked() };
+    GenericStringArray::from(data)
+}
+
+/// `filter` implementation for dictionaries
+fn filter_dict<T>(
+    array: &DictionaryArray<T>,
+    predicate: &FilterPredicate,
+) -> DictionaryArray<T>
+where
+    T: ArrowPrimitiveType,
+    T::Native: num::Num,
+{
+    let filtered_keys = filter_primitive::<T>(array.keys(), predicate);
+    let filtered_data = filtered_keys.data_ref();
+
+    let data = unsafe {
+        ArrayData::new_unchecked(
+            array.data_type().clone(),
+            filtered_data.len(),
+            Some(filtered_data.null_count()),
+            filtered_data.null_buffer().cloned(),
+            filtered_data.offset(),
+            filtered_data.buffers().to_vec(),
+            array.data().child_data().to_vec(),
+        )
+    };
+
+    DictionaryArray::<T>::from(data)
 }
 
 #[cfg(test)]
 mod tests {
-    use super::*;
+    use rand::distributions::{Alphanumeric, Standard};
+    use rand::prelude::*;
+
     use crate::datatypes::Int64Type;
     use crate::{
         buffer::Buffer,
         datatypes::{DataType, Field},
     };
-    use rand::prelude::*;
+
+    use super::*;
 
     macro_rules! def_temporal_test {
         ($test:ident, $array_type: ident, $data: expr) => {
@@ -682,12 +1305,15 @@ mod tests {
             .build()
             .unwrap();
 
-        let bool_array = BooleanArray::from(data);
+        let filter = BooleanArray::from(data);
 
-        let bits: Vec<_> = SlicesIterator::new(&bool_array)
+        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
             .flat_map(|(start, end)| start..end)
             .collect();
 
+        let count = filter_count(&filter);
+        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
+
         let expected_bits: Vec<_> = bools
             .iter()
             .skip(offset)
@@ -696,7 +1322,8 @@ mod tests {
             .flat_map(|(idx, v)| v.then(|| idx))
             .collect();
 
-        assert_eq!(bits, expected_bits);
+        assert_eq!(slice_bits, expected_bits);
+        assert_eq!(index_bits, expected_bits);
     }
 
     #[test]
@@ -720,4 +1347,141 @@ mod tests {
         test_slices_fuzz(32, 8, 8);
         test_slices_fuzz(32, 5, 9);
     }
+
+    /// Filters `values` by `predicate` using standard rust iterators
+    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
+        values
+            .into_iter()
+            .zip(predicate)
+            .filter(|(_, x)| **x)
+            .map(|(a, _)| a)
+            .collect()
+    }
+
+    /// Generates `len` values, each of which is non-null with probability `valid_percent`
+    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
+    where
+        Standard: Distribution<T>,
+    {
+        let mut rng = thread_rng();
+        (0..len)
+            .map(|_| rng.gen_bool(valid_percent).then(|| rng.gen()))
+            .collect()
+    }
+
+    /// Generates `len` strings, each of which is non-null with probability `valid_percent`
+    fn gen_strings(
+        len: usize,
+        valid_percent: f64,
+        str_len_range: std::ops::Range<usize>,
+    ) -> Vec<Option<String>> {
+        let mut rng = thread_rng();
+        (0..len)
+            .map(|_| {
+                rng.gen_bool(valid_percent).then(|| {
+                    let len = rng.gen_range(str_len_range.clone());
+                    (0..len)
+                        .map(|_| char::from(rng.sample(Alphanumeric)))
+                        .collect()
+                })
+            })
+            .collect()
+    }
+
+    /// Returns an iterator that calls `Option::as_deref` on each item
+    fn as_deref<T: std::ops::Deref>(
+        src: &[Option<T>],
+    ) -> impl Iterator<Item = Option<&T::Target>> {
+        src.iter().map(|x| x.as_deref())
+    }
+
+    #[test]
+    fn fuzz_filter() {
+        let mut rng = thread_rng();
+
+        for i in 0..100 {
+            let filter_percent = match i {
+                0..=4 => 1.,
+                5..=10 => 0.,
+                _ => rng.gen_range(0.0..1.0),
+            };
+
+            let valid_percent = rng.gen_range(0.0..1.0);
+
+            let array_len = rng.gen_range(32..256);
+            let array_offset = rng.gen_range(0..10);
+
+            // Construct a predicate
+            let filter_offset = rng.gen_range(0..10);
+            let filter_truncate = rng.gen_range(0..10);
+            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.gen_bool(filter_percent)))
+                .take(array_len + filter_offset - filter_truncate)
+                .collect();
+
+            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
+
+            // Offset predicate
+            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
+            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
+            let bools = &bools[filter_offset..];
+
+            // Test i32
+            let values = gen_primitive(array_len + array_offset, valid_percent);
+            let src = Int32Array::from_iter(values.iter().cloned());
+
+            let src = src.slice(array_offset, array_len);
+            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
+            let values = &values[array_offset..];
+
+            let filtered = filter(src, predicate).unwrap();
+            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
+            let actual: Vec<_> = array.iter().collect();
+
+            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
+
+            // Test string
+            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
+            let src = StringArray::from_iter(as_deref(&strings));
+
+            let src = src.slice(array_offset, array_len);
+            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
+
+            let filtered = filter(src, predicate).unwrap();
+            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
+            let actual: Vec<_> = array.iter().collect();
+
+            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
+            assert_eq!(actual, expected_strings);
+
+            // Test string dictionary
+            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
+
+            let src = src.slice(array_offset, array_len);
+            let src = src
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int32Type>>()
+                .unwrap();
+
+            let filtered = filter(src, predicate).unwrap();
+
+            let array = filtered
+                .as_any()
+                .downcast_ref::<DictionaryArray<Int32Type>>()
+                .unwrap();
+
+            let values = array
+                .values()
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .unwrap();
+
+            let actual: Vec<_> = array
+                .keys()
+                .iter()
+                .map(|key| key.map(|key| values.value(key as usize)))
+                .collect();
+
+            assert_eq!(actual, expected_strings);
+        }
+    }
 }
diff --git a/arrow/src/util/bench_util.rs b/arrow/src/util/bench_util.rs
index 4034033..eeb906b 100644
--- a/arrow/src/util/bench_util.rs
+++ b/arrow/src/util/bench_util.rs
@@ -110,6 +110,29 @@ pub fn create_string_array<Offset: StringOffsetSizeTrait>(
         .collect()
 }
 
+/// Creates a random (but fixed-seeded) dictionary array of a given size and null
+/// density, consisting of random 4-character alphanumeric strings
+pub fn create_string_dict_array<K: ArrowDictionaryKeyType>(
+    size: usize,
+    null_density: f32,
+) -> DictionaryArray<K> {
+    let rng = &mut seedable_rng();
+
+    let data: Vec<_> = (0..size)
+        .map(|_| {
+            if rng.gen::<f32>() < null_density {
+                None
+            } else {
+                let value = rng.sample_iter(&Alphanumeric).take(4).collect();
+                let value = String::from_utf8(value).unwrap();
+                Some(value)
+            }
+        })
+        .collect();
+
+    data.iter().map(|x| x.as_deref()).collect()
+}
+
 /// Creates an random (but fixed-seeded) binary array of a given size and null density
 pub fn create_binary_array<Offset: BinaryOffsetSizeTrait>(
     size: usize,