You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/20 21:53:13 UTC

[arrow-datafusion] branch master updated: Add support for linear range calculation in WINDOW functions (#4989)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new b71cae8aa Add support for linear range calculation in WINDOW functions (#4989)
b71cae8aa is described below

commit b71cae8aa556369bc5ee72b063ed1fc5a81192f1
Author: Mustafa Akur <10...@users.noreply.github.com>
AuthorDate: Sat Jan 21 00:53:06 2023 +0300

    Add support for linear range calculation in WINDOW functions (#4989)
    
    * add naive linear search
    
    * Add last range to decrease search size
    
    * minor changes
    
    * add low, high arguments
    
    * Go back to old API, improve comments, refactors
    
    * use util function
    
    Co-authored-by: Mehmet Ozan Kabak <oz...@gmail.com>
---
 datafusion/common/src/lib.rs                       |   2 +-
 datafusion/common/src/{bisect.rs => utils.rs}      | 119 +++++++++++++++++----
 datafusion/physical-expr/src/window/aggregate.rs   |   9 +-
 datafusion/physical-expr/src/window/built_in.rs    |   9 +-
 datafusion/physical-expr/src/window/nth_value.rs   |   6 +-
 .../src/window/partition_evaluator.rs              |   2 +-
 .../physical-expr/src/window/sliding_aggregate.rs  |   1 +
 .../physical-expr/src/window/window_frame_state.rs |  44 ++++++--
 8 files changed, 149 insertions(+), 43 deletions(-)

diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index 26fc58b8c..636feb21a 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub mod bisect;
 pub mod cast;
 mod column;
 pub mod config;
@@ -30,6 +29,7 @@ pub mod scalar;
 pub mod stats;
 mod table_reference;
 pub mod test_util;
+pub mod utils;
 
 use arrow::compute::SortOptions;
 pub use column::Column;
diff --git a/datafusion/common/src/bisect.rs b/datafusion/common/src/utils.rs
similarity index 63%
rename from datafusion/common/src/bisect.rs
rename to datafusion/common/src/utils.rs
index 796598be2..3c0730153 100644
--- a/datafusion/common/src/bisect.rs
+++ b/datafusion/common/src/utils.rs
@@ -22,8 +22,16 @@ use arrow::array::ArrayRef;
 use arrow::compute::SortOptions;
 use std::cmp::Ordering;
 
+/// Given column vectors, returns row at `idx`.
+pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result<Vec<ScalarValue>> {
+    columns
+        .iter()
+        .map(|arr| ScalarValue::try_from_array(arr, idx))
+        .collect()
+}
+
 /// This function compares two tuples depending on the given sort options.
-fn compare(
+pub fn compare_rows(
     x: &[ScalarValue],
     y: &[ScalarValue],
     sort_options: &[SortOptions],
@@ -52,9 +60,10 @@ fn compare(
     Ok(Ordering::Equal)
 }
 
-/// This function implements both bisect_left and bisect_right, having the same
-/// semantics with the Python Standard Library. To use bisect_left, supply true
-/// as the template argument. To use bisect_right, supply false as the template argument.
+/// This function searches for a tuple of given values (`target`) among the given
+/// rows (`item_columns`) using the bisection algorithm. It assumes that `item_columns`
+/// is sorted according to `sort_options` and returns the insertion index of `target`.
+/// Template argument `SIDE` being `true`/`false` means left/right insertion.
 pub fn bisect<const SIDE: bool>(
     item_columns: &[ArrayRef],
     target: &[ScalarValue],
@@ -68,16 +77,18 @@ pub fn bisect<const SIDE: bool>(
         })?
         .len();
     let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
-        let cmp = compare(current, target, sort_options)?;
+        let cmp = compare_rows(current, target, sort_options)?;
         Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
     };
     find_bisect_point(item_columns, target, compare_fn, low, high)
 }
 
-/// This function searches for a tuple of target values among the given rows using the bisection algorithm.
-/// The boolean-valued function `compare_fn` specifies whether we bisect on the left (with return value `false`),
-/// or on the right (with return value `true`) as we compare the target value with the current value as we iteratively
-/// bisect the input.
+/// This function searches for a tuple of given values (`target`) among a slice of
+/// the given rows (`item_columns`) using the bisection algorithm. The slice starts
+/// at the index `low` and ends at the index `high`. The boolean-valued function
+/// `compare_fn` specifies whether we bisect on the left (by returning `false`),
+/// or on the right (by returning `true`) when we compare the target value with
+/// the current value as we iteratively bisect the input.
 pub fn find_bisect_point<F>(
     item_columns: &[ArrayRef],
     target: &[ScalarValue],
@@ -90,10 +101,7 @@ where
 {
     while low < high {
         let mid = ((high - low) / 2) + low;
-        let val = item_columns
-            .iter()
-            .map(|arr| ScalarValue::try_from_array(arr, mid))
-            .collect::<Result<Vec<ScalarValue>>>()?;
+        let val = get_row_at_idx(item_columns, mid)?;
         if compare_fn(&val, target)? {
             low = mid + 1;
         } else {
@@ -103,6 +111,53 @@ where
     Ok(low)
 }
 
+/// This function searches for a tuple of given values (`target`) among the given
+/// rows (`item_columns`) via a linear scan. It assumes that `item_columns` is sorted
+/// according to `sort_options` and returns the insertion index of `target`.
+/// Template argument `SIDE` being `true`/`false` means left/right insertion.
+pub fn linear_search<const SIDE: bool>(
+    item_columns: &[ArrayRef],
+    target: &[ScalarValue],
+    sort_options: &[SortOptions],
+) -> Result<usize> {
+    let low: usize = 0;
+    let high: usize = item_columns
+        .get(0)
+        .ok_or_else(|| {
+            DataFusionError::Internal("Column array shouldn't be empty".to_string())
+        })?
+        .len();
+    let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
+        let cmp = compare_rows(current, target, sort_options)?;
+        Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
+    };
+    search_in_slice(item_columns, target, compare_fn, low, high)
+}
+
+/// This function searches for a tuple of given values (`target`) among a slice of
+/// the given rows (`item_columns`) via a linear scan. The slice starts at the index
+/// `low` and ends at the index `high`. The boolean-valued function `compare_fn`
+/// specifies the stopping criterion.
+pub fn search_in_slice<F>(
+    item_columns: &[ArrayRef],
+    target: &[ScalarValue],
+    compare_fn: F,
+    mut low: usize,
+    high: usize,
+) -> Result<usize>
+where
+    F: Fn(&[ScalarValue], &[ScalarValue]) -> Result<bool>,
+{
+    while low < high {
+        let val = get_row_at_idx(item_columns, low)?;
+        if !compare_fn(&val, target)? {
+            break;
+        }
+        low += 1;
+    }
+    Ok(low)
+}
+
 #[cfg(test)]
 mod tests {
     use arrow::array::Float64Array;
@@ -115,7 +170,7 @@ mod tests {
     use super::*;
 
     #[test]
-    fn test_bisect_left_and_right() {
+    fn test_bisect_linear_left_and_right() -> Result<()> {
         let arrays: Vec<ArrayRef> = vec![
             Arc::new(Float64Array::from_slice([5.0, 7.0, 8.0, 9., 10.])),
             Arc::new(Float64Array::from_slice([2.0, 3.0, 3.0, 4.0, 5.0])),
@@ -146,10 +201,15 @@ mod tests {
                 nulls_first: true,
             },
         ];
-        let res: usize = bisect::<true>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<true>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 2);
-        let res: usize = bisect::<false>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<false>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 3);
+        let res = linear_search::<true>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 2);
+        let res = linear_search::<false>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 3);
+        Ok(())
     }
 
     #[test]
@@ -186,7 +246,7 @@ mod tests {
     }
 
     #[test]
-    fn test_bisect_left_and_right_diff_sort() {
+    fn test_bisect_linear_left_and_right_diff_sort() -> Result<()> {
         // Descending, left
         let arrays: Vec<ArrayRef> = vec![Arc::new(Float64Array::from_slice([
             4.0, 3.0, 2.0, 1.0, 0.0,
@@ -196,7 +256,9 @@ mod tests {
             descending: true,
             nulls_first: true,
         }];
-        let res: usize = bisect::<true>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<true>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 0);
+        let res = linear_search::<true>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 0);
 
         // Descending, right
@@ -208,7 +270,9 @@ mod tests {
             descending: true,
             nulls_first: true,
         }];
-        let res: usize = bisect::<false>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<false>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 1);
+        let res = linear_search::<false>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 1);
 
         // Ascending, left
@@ -219,7 +283,9 @@ mod tests {
             descending: false,
             nulls_first: true,
         }];
-        let res: usize = bisect::<true>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<true>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 1);
+        let res = linear_search::<true>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 1);
 
         // Ascending, right
@@ -230,7 +296,9 @@ mod tests {
             descending: false,
             nulls_first: true,
         }];
-        let res: usize = bisect::<false>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<false>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 2);
+        let res = linear_search::<false>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 2);
 
         let arrays: Vec<ArrayRef> = vec![
@@ -251,10 +319,15 @@ mod tests {
                 nulls_first: true,
             },
         ];
-        let res: usize = bisect::<false>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<false>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 3);
+        let res = linear_search::<false>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 3);
 
-        let res: usize = bisect::<true>(&arrays, &search_tuple, &ords).unwrap();
+        let res = bisect::<true>(&arrays, &search_tuple, &ords)?;
+        assert_eq!(res, 2);
+        let res = linear_search::<true>(&arrays, &search_tuple, &ords)?;
         assert_eq!(res, 2);
+        Ok(())
     }
 }
diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs
index df61e7cc8..fe725f2d7 100644
--- a/datafusion/physical-expr/src/window/aggregate.rs
+++ b/datafusion/physical-expr/src/window/aggregate.rs
@@ -106,8 +106,13 @@ impl WindowExpr for AggregateWindowExpr {
         // We iterate on each row to perform a running calculation.
         // First, cur_range is calculated, then it is compared with last_range.
         for i in 0..length {
-            let cur_range =
-                window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?;
+            let cur_range = window_frame_ctx.calculate_range(
+                &order_bys,
+                &sort_options,
+                length,
+                i,
+                &last_range,
+            )?;
             let value = if cur_range.end == cur_range.start {
                 // We produce None if the window is empty.
                 ScalarValue::try_from(self.aggregate.field()?.data_type())?
diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs
index f0484b790..b73e2b8de 100644
--- a/datafusion/physical-expr/src/window/built_in.rs
+++ b/datafusion/physical-expr/src/window/built_in.rs
@@ -34,6 +34,7 @@ use datafusion_common::ScalarValue;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::{WindowFrame, WindowFrameUnits};
 use std::any::Any;
+use std::ops::Range;
 use std::sync::Arc;
 
 /// A window expr that takes the form of a built in window function
@@ -101,18 +102,19 @@ impl WindowExpr for BuiltInWindowExpr {
                 self.order_by.iter().map(|o| o.options).collect();
             let mut row_wise_results = vec![];
 
-            let length = batch.num_rows();
             let (values, order_bys) = self.get_values_orderbys(batch)?;
             let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
+            let range = Range { start: 0, end: 0 };
             // We iterate on each row to calculate window frame range and and window function result
-            for idx in 0..length {
+            for idx in 0..num_rows {
                 let range = window_frame_ctx.calculate_range(
                     &order_bys,
                     &sort_options,
                     num_rows,
                     idx,
+                    &range,
                 )?;
-                let value = evaluator.evaluate_inside_range(&values, range)?;
+                let value = evaluator.evaluate_inside_range(&values, &range)?;
                 row_wise_results.push(value);
             }
             ScalarValue::iter_to_array(row_wise_results.into_iter())
@@ -185,6 +187,7 @@ impl WindowExpr for BuiltInWindowExpr {
                         &sort_options,
                         num_rows,
                         idx,
+                        &state.window_frame_range,
                     )
                 } else {
                     evaluator.get_range(state, num_rows)
diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index c3c3b55d4..c40a4fa7d 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -176,13 +176,13 @@ impl PartitionEvaluator for NthValueEvaluator {
     }
 
     fn evaluate_stateful(&mut self, values: &[ArrayRef]) -> Result<ScalarValue> {
-        self.evaluate_inside_range(values, self.state.range.clone())
+        self.evaluate_inside_range(values, &self.state.range)
     }
 
     fn evaluate_inside_range(
         &self,
         values: &[ArrayRef],
-        range: Range<usize>,
+        range: &Range<usize>,
     ) -> Result<ScalarValue> {
         // FIRST_VALUE, LAST_VALUE, NTH_VALUE window functions take single column, values will have size 1
         let arr = &values[0];
@@ -227,7 +227,7 @@ mod tests {
         let evaluator = expr.create_evaluator()?;
         let values = expr.evaluate_args(&batch)?;
         let result = ranges
-            .into_iter()
+            .iter()
             .map(|range| evaluator.evaluate_inside_range(&values, range))
             .into_iter()
             .collect::<Result<Vec<ScalarValue>>>()?;
diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs
index e6cead76d..44fbb2d94 100644
--- a/datafusion/physical-expr/src/window/partition_evaluator.rs
+++ b/datafusion/physical-expr/src/window/partition_evaluator.rs
@@ -83,7 +83,7 @@ pub trait PartitionEvaluator: Debug + Send {
     fn evaluate_inside_range(
         &self,
         _values: &[ArrayRef],
-        _range: Range<usize>,
+        _range: &Range<usize>,
     ) -> Result<ScalarValue> {
         Err(DataFusionError::NotImplemented(
             "evaluate_inside_range is not implemented by default".into(),
diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs
index 587c313e3..a429f658c 100644
--- a/datafusion/physical-expr/src/window/sliding_aggregate.rs
+++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs
@@ -268,6 +268,7 @@ impl SlidingAggregateWindowExpr {
                 &sort_options,
                 length,
                 *idx,
+                last_range,
             )?;
             // Exit if range end index is length, need kind of flag to stop
             if cur_range.end == length && !is_end {
diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs
index 9c559cabd..9cde3cbdf 100644
--- a/datafusion/physical-expr/src/window/window_frame_state.rs
+++ b/datafusion/physical-expr/src/window/window_frame_state.rs
@@ -20,7 +20,9 @@
 
 use arrow::array::ArrayRef;
 use arrow::compute::kernels::sort::SortOptions;
-use datafusion_common::bisect::{bisect, find_bisect_point};
+use datafusion_common::utils::{
+    compare_rows, find_bisect_point, get_row_at_idx, search_in_slice,
+};
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
 use std::cmp::min;
@@ -69,6 +71,7 @@ impl<'a> WindowFrameContext<'a> {
         sort_options: &[SortOptions],
         length: usize,
         idx: usize,
+        last_range: &Range<usize>,
     ) -> Result<Range<usize>> {
         match *self {
             WindowFrameContext::Rows(window_frame) => {
@@ -85,6 +88,7 @@ impl<'a> WindowFrameContext<'a> {
                 sort_options,
                 length,
                 idx,
+                last_range,
             ),
             // sort_options is not used in GROUPS mode calculations as the inequality of two rows is the indicator
             // of a group change, and the ordering and the position of the nulls do not have impact on inequality.
@@ -170,6 +174,7 @@ impl WindowFrameStateRange {
         sort_options: &[SortOptions],
         length: usize,
         idx: usize,
+        last_range: &Range<usize>,
     ) -> Result<Range<usize>> {
         let start = match window_frame.start_bound {
             WindowFrameBound::Preceding(ref n) => {
@@ -182,6 +187,8 @@ impl WindowFrameStateRange {
                         sort_options,
                         idx,
                         Some(n),
+                        last_range,
+                        length,
                     )?
                 }
             }
@@ -194,6 +201,8 @@ impl WindowFrameStateRange {
                         sort_options,
                         idx,
                         None,
+                        last_range,
+                        length,
                     )?
                 }
             }
@@ -203,6 +212,8 @@ impl WindowFrameStateRange {
                     sort_options,
                     idx,
                     Some(n),
+                    last_range,
+                    length,
                 )?,
         };
         let end = match window_frame.end_bound {
@@ -212,6 +223,8 @@ impl WindowFrameStateRange {
                     sort_options,
                     idx,
                     Some(n),
+                    last_range,
+                    length,
                 )?,
             WindowFrameBound::CurrentRow => {
                 if range_columns.is_empty() {
@@ -222,6 +235,8 @@ impl WindowFrameStateRange {
                         sort_options,
                         idx,
                         None,
+                        last_range,
+                        length,
                     )?
                 }
             }
@@ -235,6 +250,8 @@ impl WindowFrameStateRange {
                         sort_options,
                         idx,
                         Some(n),
+                        last_range,
+                        length,
                     )?
                 }
             }
@@ -243,19 +260,18 @@ impl WindowFrameStateRange {
     }
 
     /// This function does the heavy lifting when finding range boundaries. It is meant to be
-    /// called twice, in succession, to get window frame start and end indices (with `BISECT_SIDE`
-    /// supplied as false and true, respectively).
-    fn calculate_index_of_row<const BISECT_SIDE: bool, const SEARCH_SIDE: bool>(
+    /// called twice, in succession, to get window frame start and end indices (with `SIDE`
+    /// supplied as true and false, respectively).
+    fn calculate_index_of_row<const SIDE: bool, const SEARCH_SIDE: bool>(
         &mut self,
         range_columns: &[ArrayRef],
         sort_options: &[SortOptions],
         idx: usize,
         delta: Option<&ScalarValue>,
+        last_range: &Range<usize>,
+        length: usize,
     ) -> Result<usize> {
-        let current_row_values = range_columns
-            .iter()
-            .map(|col| ScalarValue::try_from_array(col, idx))
-            .collect::<Result<Vec<ScalarValue>>>()?;
+        let current_row_values = get_row_at_idx(range_columns, idx)?;
         let end_range = if let Some(delta) = delta {
             let is_descending: bool = sort_options
                 .first()
@@ -285,8 +301,16 @@ impl WindowFrameStateRange {
         } else {
             current_row_values
         };
-        // `BISECT_SIDE` true means bisect_left, false means bisect_right
-        bisect::<BISECT_SIDE>(range_columns, &end_range, sort_options)
+        let search_start = if SIDE {
+            last_range.start
+        } else {
+            last_range.end
+        };
+        let compare_fn = |current: &[ScalarValue], target: &[ScalarValue]| {
+            let cmp = compare_rows(current, target, sort_options)?;
+            Ok(if SIDE { cmp.is_lt() } else { cmp.is_le() })
+        };
+        search_in_slice(range_columns, &end_range, compare_fn, search_start, length)
     }
 }