You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/11/04 15:04:00 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #4078: Custom window frame support extended to built-in window functions

alamb commented on code in PR #4078:
URL: https://github.com/apache/arrow-datafusion/pull/4078#discussion_r1014125149


##########
datafusion/physical-expr/src/window/aggregate.rs:
##########
@@ -103,368 +82,86 @@ impl WindowExpr for AggregateWindowExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        let num_rows = batch.num_rows();
+        let partition_columns = self.partition_columns(batch)?;
         let partition_points =
-            self.evaluate_partition_points(num_rows, &self.partition_columns(batch)?)?;
+            self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
         let values = self.evaluate_args(batch)?;
 
+        let sort_options: Vec<SortOptions> =
+            self.order_by.iter().map(|o| o.options).collect();
         let columns = self.sort_columns(batch)?;
-        let array_refs: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect();
+        let order_columns: Vec<&ArrayRef> = columns.iter().map(|s| &s.values).collect();
         // Sort values, this will make the same partitions consecutive. Also, within the partition
         // range, values will be sorted.
-        let results = partition_points
-            .iter()
-            .map(|partition_range| {
-                let mut window_accumulators = self.create_accumulator()?;
-                Ok(vec![window_accumulators.scan(
-                    &values,
-                    &array_refs,
-                    partition_range,
-                )?])
-            })
-            .collect::<Result<Vec<Vec<ArrayRef>>>>()?
-            .into_iter()
-            .flatten()
-            .collect::<Vec<ArrayRef>>();
-        let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
-        concat(&results).map_err(DataFusionError::ArrowError)
-    }
-
-    fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
-        &self.partition_by
-    }
-
-    fn order_by(&self) -> &[PhysicalSortExpr] {
-        &self.order_by
-    }
-}
-
-fn calculate_index_of_row<const BISECT_SIDE: bool, const SEARCH_SIDE: bool>(
-    range_columns: &[ArrayRef],
-    sort_options: &[SortOptions],
-    idx: usize,
-    delta: Option<&ScalarValue>,
-) -> Result<usize> {
-    let current_row_values = range_columns
-        .iter()
-        .map(|col| ScalarValue::try_from_array(col, idx))
-        .collect::<Result<Vec<ScalarValue>>>()?;
-    let end_range = if let Some(delta) = delta {
-        let is_descending: bool = sort_options
-            .first()
-            .ok_or_else(|| DataFusionError::Internal("Array is empty".to_string()))?
-            .descending;
-
-        current_row_values
-            .iter()
-            .map(|value| {
-                if value.is_null() {
-                    return Ok(value.clone());
-                }
-                if SEARCH_SIDE == is_descending {
-                    // TODO: Handle positive overflows
-                    value.add(delta)
-                } else if value.is_unsigned() && value < delta {
-                    // NOTE: This gets a polymorphic zero without having long coercion code for ScalarValue.
-                    //       If we decide to implement a "default" construction mechanism for ScalarValue,
-                    //       change the following statement to use that.
-                    value.sub(value)
+        let order_bys = &order_columns[self.partition_by.len()..];
+        let window_frame = if !order_bys.is_empty() && self.window_frame.is_none() {
+            // OVER (ORDER BY a) case
+            // We create an implicit window for ORDER BY.
+            Some(Arc::new(WindowFrame::default()))
+        } else {
+            self.window_frame.clone()
+        };
+        let mut row_wise_results: Vec<ScalarValue> = vec![];
+        for partition_range in &partition_points {

Review Comment:
   This reorganization is very nice and makes the code much easier to read. Very nice 👍 



##########
datafusion/physical-expr/src/window/window_expr.rs:
##########
@@ -110,4 +115,208 @@ pub trait WindowExpr: Send + Sync + Debug {
         sort_columns.extend(order_by_columns);
         Ok(sort_columns)
     }
+
+    /// We use start and end bounds to calculate current row's starting and ending range.
+    /// This function supports different modes, but we currently do not support window calculation for GROUPS inside window frames.
+    fn calculate_range(
+        &self,
+        window_frame: &Option<Arc<WindowFrame>>,
+        range_columns: &[ArrayRef],
+        sort_options: &[SortOptions],
+        length: usize,
+        idx: usize,
+    ) -> Result<(usize, usize)> {
+        if let Some(window_frame) = window_frame {
+            match window_frame.units {
+                WindowFrameUnits::Range => {
+                    let start = match &window_frame.start_bound {
+                        // UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(n) => {
+                            if n.is_null() {
+                                0
+                            } else {
+                                calculate_index_of_row::<true, true>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    Some(n),
+                                )?
+                            }
+                        }
+                        WindowFrameBound::CurrentRow => {
+                            if range_columns.is_empty() {
+                                0
+                            } else {
+                                calculate_index_of_row::<true, true>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    None,
+                                )?
+                            }
+                        }
+                        WindowFrameBound::Following(n) => {
+                            calculate_index_of_row::<true, false>(
+                                range_columns,
+                                sort_options,
+                                idx,
+                                Some(n),
+                            )?
+                        }
+                    };
+                    let end = match &window_frame.end_bound {
+                        WindowFrameBound::Preceding(n) => {
+                            calculate_index_of_row::<false, true>(
+                                range_columns,
+                                sort_options,
+                                idx,
+                                Some(n),
+                            )?
+                        }
+                        WindowFrameBound::CurrentRow => {
+                            if range_columns.is_empty() {
+                                length
+                            } else {
+                                calculate_index_of_row::<false, false>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    None,
+                                )?
+                            }
+                        }
+                        WindowFrameBound::Following(n) => {
+                            if n.is_null() {
+                                // UNBOUNDED FOLLOWING
+                                length
+                            } else {
+                                calculate_index_of_row::<false, false>(
+                                    range_columns,
+                                    sort_options,
+                                    idx,
+                                    Some(n),
+                                )?
+                            }
+                        }
+                    };
+                    Ok((start, end))
+                }
+                WindowFrameUnits::Rows => {
+                    let start = match window_frame.start_bound {
+                        // UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0,
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
+                            if idx >= n as usize {
+                                idx - n as usize
+                            } else {
+                                0
+                            }
+                        }
+                        WindowFrameBound::Preceding(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                        WindowFrameBound::CurrentRow => idx,
+                        // UNBOUNDED FOLLOWING
+                        WindowFrameBound::Following(ScalarValue::UInt64(None)) => {
+                            return Err(DataFusionError::Internal(format!(
+                                "Frame start cannot be UNBOUNDED FOLLOWING '{:?}'",
+                                window_frame
+                            )))
+                        }
+                        WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
+                            min(idx + n as usize, length)
+                        }
+                        WindowFrameBound::Following(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                    };
+                    let end = match window_frame.end_bound {
+                        // UNBOUNDED PRECEDING
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => {
+                            return Err(DataFusionError::Internal(format!(
+                                "Frame end cannot be UNBOUNDED PRECEDING '{:?}'",
+                                window_frame
+                            )))
+                        }
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(Some(n))) => {
+                            if idx >= n as usize {
+                                idx - n as usize + 1
+                            } else {
+                                0
+                            }
+                        }
+                        WindowFrameBound::Preceding(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                        WindowFrameBound::CurrentRow => idx + 1,
+                        // UNBOUNDED FOLLOWING
+                        WindowFrameBound::Following(ScalarValue::UInt64(None)) => length,
+                        WindowFrameBound::Following(ScalarValue::UInt64(Some(n))) => {
+                            min(idx + n as usize + 1, length)
+                        }
+                        WindowFrameBound::Following(_) => {
+                            return Err(DataFusionError::Internal(
+                                "Rows should be Uint".to_string(),
+                            ))
+                        }
+                    };
+                    Ok((start, end))
+                }
+                WindowFrameUnits::Groups => Err(DataFusionError::NotImplemented(
+                    "Window frame for groups is not implemented".to_string(),

Review Comment:
   👍 



##########
datafusion/physical-expr/src/window/window_expr.rs:
##########
@@ -110,4 +115,208 @@ pub trait WindowExpr: Send + Sync + Debug {
         sort_columns.extend(order_by_columns);
         Ok(sort_columns)
     }
+
+    /// We use start and end bounds to calculate current row's starting and ending range.
+    /// This function supports different modes, but we currently do not support window calculation for GROUPS inside window frames.
+    fn calculate_range(

Review Comment:
   I find this logic to be very straightforward and easy to follow 👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org