Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/16 10:25:43 UTC
[arrow-datafusion] branch master updated: add window function
implementation with order_by clause (#520)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 51e5445 add window function implementation with order_by clause (#520)
51e5445 is described below
commit 51e5445fa51cef4f72df5db7804906a729fc5aa6
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Wed Jun 16 18:25:35 2021 +0800
add window function implementation with order_by clause (#520)
---
datafusion/src/execution/context.rs | 55 ++++-
.../src/physical_plan/expressions/nth_value.rs | 137 ++++--------
.../src/physical_plan/expressions/row_number.rs | 89 +-------
datafusion/src/physical_plan/hash_aggregate.rs | 4 +-
datafusion/src/physical_plan/mod.rs | 130 +++++------
datafusion/src/physical_plan/planner.rs | 15 +-
datafusion/src/physical_plan/window_functions.rs | 14 +-
datafusion/src/physical_plan/windows.rs | 244 +++++++++++----------
datafusion/src/scalar.rs | 2 +-
datafusion/src/sql/planner.rs | 2 +-
datafusion/tests/sql.rs | 147 +++++++++++--
.../sqls/simple_window_ordered_aggregation.sql | 26 +++
integration-tests/test_psql_parity.py | 2 +-
13 files changed, 476 insertions(+), 391 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index f09d7f4..1835244 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1273,7 +1273,17 @@ mod tests {
#[tokio::test]
async fn window() -> Result<()> {
let results = execute(
- "SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
+ "SELECT \
+ c1, \
+ c2, \
+ SUM(c2) OVER (), \
+ COUNT(c2) OVER (), \
+ MAX(c2) OVER (), \
+ MIN(c2) OVER (), \
+ AVG(c2) OVER () \
+ FROM test \
+ ORDER BY c1, c2 \
+ LIMIT 5",
4,
)
.await?;
@@ -1300,6 +1310,49 @@ mod tests {
}
#[tokio::test]
+ async fn window_order_by() -> Result<()> {
+ let results = execute(
+ "SELECT \
+ c1, \
+ c2, \
+ ROW_NUMBER() OVER (ORDER BY c1, c2), \
+ FIRST_VALUE(c2) OVER (ORDER BY c1, c2), \
+ LAST_VALUE(c2) OVER (ORDER BY c1, c2), \
+ NTH_VALUE(c2, 2) OVER (ORDER BY c1, c2), \
+ SUM(c2) OVER (ORDER BY c1, c2), \
+ COUNT(c2) OVER (ORDER BY c1, c2), \
+ MAX(c2) OVER (ORDER BY c1, c2), \
+ MIN(c2) OVER (ORDER BY c1, c2), \
+ AVG(c2) OVER (ORDER BY c1, c2) \
+ FROM test \
+ ORDER BY c1, c2 \
+ LIMIT 5",
+ 4,
+ )
+ .await?;
+ // results arrive in one batch; although having e.g. 2 batches would not
+ // change the result semantics, a len=1 assertion upfront keeps surprises
+ // at bay
+ assert_eq!(results.len(), 1);
+
+ let expected = vec![
+ "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
+ "| c1 | c2 | ROW_NUMBER() | FIRST_VALUE(c2) | LAST_VALUE(c2) | NTH_VALUE(c2,Int64(2)) | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+ "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
+ "| 0 | 1 | 1 | 1 | 10 | 2 | 1 | 1 | 1 | 1 | 1 |",
+ "| 0 | 2 | 2 | 1 | 10 | 2 | 3 | 2 | 2 | 1 | 1.5 |",
+ "| 0 | 3 | 3 | 1 | 10 | 2 | 6 | 3 | 3 | 1 | 2 |",
+ "| 0 | 4 | 4 | 1 | 10 | 2 | 10 | 4 | 4 | 1 | 2.5 |",
+ "| 0 | 5 | 5 | 1 | 10 | 2 | 15 | 5 | 5 | 1 | 3 |",
+ "+----+----+--------------+-----------------+----------------+------------------------+---------+-----------+---------+---------+---------+",
+ ];
+
+ // window function shall respect ordering
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index fb0e79f..98083fa 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -18,13 +18,11 @@
//! Defines physical expressions that can be evaluated at runtime during query execution
use crate::error::{DataFusionError, Result};
-use crate::physical_plan::{
- window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
-};
+use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
+use arrow::array::{new_empty_array, ArrayRef};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
-use std::convert::TryFrom;
use std::sync::Arc;
/// nth_value kind
@@ -113,54 +111,32 @@ impl BuiltInWindowFunctionExpr for NthValue {
&self.name
}
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
- Ok(Box::new(NthValueAccumulator::try_new(
- self.kind,
- self.data_type.clone(),
- )?))
- }
-}
-
-#[derive(Debug)]
-struct NthValueAccumulator {
- kind: NthValueKind,
- offset: u32,
- value: ScalarValue,
-}
-
-impl NthValueAccumulator {
- /// new count accumulator
- pub fn try_new(kind: NthValueKind, data_type: DataType) -> Result<Self> {
- Ok(Self {
- kind,
- offset: 0,
- // null value of that data_type by default
- value: ScalarValue::try_from(&data_type)?,
- })
- }
-}
-
-impl WindowAccumulator for NthValueAccumulator {
- fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
- self.offset += 1;
- match self.kind {
- NthValueKind::Last => {
- self.value = values[0].clone();
- }
- NthValueKind::First if self.offset == 1 => {
- self.value = values[0].clone();
- }
- NthValueKind::Nth(n) if self.offset == n => {
- self.value = values[0].clone();
- }
- _ => {}
+ fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef> {
+ if values.is_empty() {
+ return Err(DataFusionError::Execution(format!(
+ "No arguments supplied to {}",
+ self.name()
+ )));
}
-
- Ok(None)
- }
-
- fn evaluate(&self) -> Result<Option<ScalarValue>> {
- Ok(Some(self.value.clone()))
+ let value = &values[0];
+ if value.len() != num_rows {
+ return Err(DataFusionError::Execution(format!(
+ "Invalid data supplied to {}, expect {} rows, got {} rows",
+ self.name(),
+ num_rows,
+ value.len()
+ )));
+ }
+ if num_rows == 0 {
+ return Ok(new_empty_array(value.data_type()));
+ }
+ let index: usize = match self.kind {
+ NthValueKind::First => 0,
+ NthValueKind::Last => (num_rows as usize) - 1,
+ NthValueKind::Nth(n) => (n as usize) - 1,
+ };
+ let value = ScalarValue::try_from_array(value, index)?;
+ Ok(value.to_array_of_size(num_rows))
}
}
@@ -172,68 +148,47 @@ mod tests {
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
- fn test_i32_result(expr: Arc<NthValue>, expected: i32) -> Result<()> {
+ fn test_i32_result(expr: NthValue, expected: Vec<i32>) -> Result<()> {
let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8]));
+ let values = vec![arr];
let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]);
- let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
-
- let mut acc = expr.create_accumulator()?;
- let expr = expr.expressions();
- let values = expr
- .iter()
- .map(|e| e.evaluate(&batch))
- .map(|r| r.map(|v| v.into_array(batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
- let result = acc.scan_batch(batch.num_rows(), &values)?;
- assert_eq!(false, result.is_some());
- let result = acc.evaluate()?;
- assert_eq!(Some(ScalarValue::Int32(Some(expected))), result);
+ let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?;
+ let result = expr.evaluate(batch.num_rows(), &values)?;
+ let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+ let result = result.values();
+ assert_eq!(expected, result);
Ok(())
}
#[test]
fn first_value() -> Result<()> {
- let first_value = Arc::new(NthValue::first_value(
- "first_value".to_owned(),
- col("arr"),
- DataType::Int32,
- ));
- test_i32_result(first_value, 1)?;
+ let first_value =
+ NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32);
+ test_i32_result(first_value, vec![1; 8])?;
Ok(())
}
#[test]
fn last_value() -> Result<()> {
- let last_value = Arc::new(NthValue::last_value(
- "last_value".to_owned(),
- col("arr"),
- DataType::Int32,
- ));
- test_i32_result(last_value, 8)?;
+ let last_value =
+ NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32);
+ test_i32_result(last_value, vec![8; 8])?;
Ok(())
}
#[test]
fn nth_value_1() -> Result<()> {
- let nth_value = Arc::new(NthValue::nth_value(
- "nth_value".to_owned(),
- col("arr"),
- DataType::Int32,
- 1,
- )?);
- test_i32_result(nth_value, 1)?;
+ let nth_value =
+ NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?;
+ test_i32_result(nth_value, vec![1; 8])?;
Ok(())
}
#[test]
fn nth_value_2() -> Result<()> {
- let nth_value = Arc::new(NthValue::nth_value(
- "nth_value".to_owned(),
- col("arr"),
- DataType::Int32,
- 2,
- )?);
- test_i32_result(nth_value, -2)?;
+ let nth_value =
+ NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?;
+ test_i32_result(nth_value, vec![-2; 8])?;
Ok(())
}
}
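A minimal sketch (not part of the commit) of the broadcast semantics that the new NthValue::evaluate implements: pick a single value out of the sorted partition and repeat it for every row. Plain slices stand in for Arrow arrays here, and the Kind enum and nth_value_broadcast function are illustrative names only:

    #[derive(Clone, Copy)]
    enum Kind {
        First,
        Last,
        Nth(usize), // 1-based, as in SQL's nth_value
    }

    fn nth_value_broadcast(values: &[i32], kind: Kind) -> Option<Vec<i32>> {
        if values.is_empty() {
            // the commit returns an empty array here; None keeps the sketch small
            return None;
        }
        let index = match kind {
            Kind::First => 0,
            Kind::Last => values.len() - 1,
            Kind::Nth(n) => n.checked_sub(1)?, // convert 1-based to 0-based
        };
        let v = *values.get(index)?;
        Some(vec![v; values.len()]) // one value repeated for every row
    }

    fn main() {
        let arr = [1, -2, 3, -4, 5, -6, 7, 8];
        assert_eq!(nth_value_broadcast(&arr, Kind::First), Some(vec![1; 8]));
        assert_eq!(nth_value_broadcast(&arr, Kind::Last), Some(vec![8; 8]));
        assert_eq!(nth_value_broadcast(&arr, Kind::Nth(2)), Some(vec![-2; 8]));
    }

The asserts line up with the expected vectors in the updated tests above.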
diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs
index eaf9b21..0444ee9 100644
--- a/datafusion/src/physical_plan/expressions/row_number.rs
+++ b/datafusion/src/physical_plan/expressions/row_number.rs
@@ -18,10 +18,7 @@
//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution
use crate::error::Result;
-use crate::physical_plan::{
- window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
-};
-use crate::scalar::ScalarValue;
+use crate::physical_plan::{window_functions::BuiltInWindowFunctionExpr, PhysicalExpr};
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
@@ -60,46 +57,10 @@ impl BuiltInWindowFunctionExpr for RowNumber {
self.name.as_str()
}
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
- Ok(Box::new(RowNumberAccumulator::new()))
- }
-}
-
-#[derive(Debug)]
-struct RowNumberAccumulator {
- row_number: u64,
-}
-
-impl RowNumberAccumulator {
- /// new row_number accumulator
- pub fn new() -> Self {
- // row number is 1 based
- Self { row_number: 1 }
- }
-}
-
-impl WindowAccumulator for RowNumberAccumulator {
- fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
- let result = Some(ScalarValue::UInt64(Some(self.row_number)));
- self.row_number += 1;
- Ok(result)
- }
-
- fn scan_batch(
- &mut self,
- num_rows: usize,
- _values: &[ArrayRef],
- ) -> Result<Option<ArrayRef>> {
- let new_row_number = self.row_number + (num_rows as u64);
- // TODO: probably would be nice to have a (optimized) kernel for this at some point to
- // generate an array like this.
- let result = UInt64Array::from_iter_values(self.row_number..new_row_number);
- self.row_number = new_row_number;
- Ok(Some(Arc::new(result)))
- }
-
- fn evaluate(&self) -> Result<Option<ScalarValue>> {
- Ok(None)
+ fn evaluate(&self, num_rows: usize, _values: &[ArrayRef]) -> Result<ArrayRef> {
+ Ok(Arc::new(UInt64Array::from_iter_values(
+ (1..num_rows + 1).map(|i| i as u64),
+ )))
}
}
@@ -117,27 +78,11 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
-
- let row_number = Arc::new(RowNumber::new("row_number".to_owned()));
-
- let mut acc = row_number.create_accumulator()?;
- let expr = row_number.expressions();
- let values = expr
- .iter()
- .map(|e| e.evaluate(&batch))
- .map(|r| r.map(|v| v.into_array(batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
-
- let result = acc.scan_batch(batch.num_rows(), &values)?;
- assert_eq!(true, result.is_some());
-
- let result = result.unwrap();
+ let row_number = RowNumber::new("row_number".to_owned());
+ let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
-
- let result = acc.evaluate()?;
- assert_eq!(false, result.is_some());
Ok(())
}
@@ -148,27 +93,11 @@ mod tests {
]));
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
-
- let row_number = Arc::new(RowNumber::new("row_number".to_owned()));
-
- let mut acc = row_number.create_accumulator()?;
- let expr = row_number.expressions();
- let values = expr
- .iter()
- .map(|e| e.evaluate(&batch))
- .map(|r| r.map(|v| v.into_array(batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
-
- let result = acc.scan_batch(batch.num_rows(), &values)?;
- assert_eq!(true, result.is_some());
-
- let result = result.unwrap();
+ let row_number = RowNumber::new("row_number".to_owned());
+ let result = row_number.evaluate(batch.num_rows(), &[])?;
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
let result = result.values();
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
-
- let result = acc.evaluate()?;
- assert_eq!(false, result.is_some());
Ok(())
}
}
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 453d500..f1611eb 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -500,7 +500,7 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
// look up the index in the values dictionary
- let keys_col = dict_col.keys_array();
+ let keys_col = dict_col.keys();
let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
@@ -1083,7 +1083,7 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
// look up the index in the values dictionary
- let keys_col = dict_col.keys_array();
+ let keys_col = dict_col.keys();
let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 2dcba80..713956f 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -17,17 +17,16 @@
//! Traits for physical query plan, supporting parallel execution for partitioned relations.
-use std::fmt;
-use std::fmt::{Debug, Display};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-
+use self::{display::DisplayableExecutionPlan, merge::MergeExec};
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::LogicalPlan;
+use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::{
error::{DataFusionError, Result},
scalar::ScalarValue,
};
+use arrow::compute::kernels::partition::lexicographical_partition_ranges;
+use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
@@ -35,10 +34,13 @@ use arrow::{array::ArrayRef, datatypes::Field};
use async_trait::async_trait;
pub use display::DisplayFormatType;
use futures::stream::Stream;
-use std::{any::Any, pin::Pin};
-
-use self::{display::DisplayableExecutionPlan, merge::MergeExec};
use hashbrown::HashMap;
+use std::fmt;
+use std::fmt::{Debug, Display};
+use std::ops::Range;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::{any::Any, pin::Pin};
/// Trait for types that stream [arrow::record_batch::RecordBatch]
pub trait RecordBatchStream: Stream<Item = ArrowResult<RecordBatch>> {
@@ -465,15 +467,65 @@ pub trait WindowExpr: Send + Sync + Debug {
"WindowExpr: default name"
}
- /// the accumulator used to accumulate values from the expressions.
- /// the accumulator expects the same number of arguments as `expressions` and must
- /// return states with the same description as `state_fields`
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
-
/// expressions that are passed to the WindowAccumulator.
/// Functions which take a single input argument, such as `sum`, return a single [`Expr`],
/// others (e.g. `cov`) return many.
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
+
+ /// evaluate the window function arguments against the batch and return
+ /// array refs; normally the resulting vec has a single element.
+ fn evaluate_args(&self, batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
+ self.expressions()
+ .iter()
+ .map(|e| e.evaluate(batch))
+ .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+ .collect()
+ }
+
+ /// evaluate the window function values against the batch
+ fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
+
+ /// evaluate the sort partition points
+ fn evaluate_sort_partition_points(
+ &self,
+ batch: &RecordBatch,
+ ) -> Result<Vec<Range<usize>>> {
+ let sort_columns = self.sort_columns(batch)?;
+ if sort_columns.is_empty() {
+ Ok(vec![Range {
+ start: 0,
+ end: batch.num_rows(),
+ }])
+ } else {
+ lexicographical_partition_ranges(&sort_columns)
+ .map_err(DataFusionError::ArrowError)
+ }
+ }
+
+ /// expressions from the window function's partition by clause; empty if absent
+ fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>];
+
+ /// expressions from the window function's order by clause; empty if absent
+ fn order_by(&self) -> &[PhysicalSortExpr];
+
+ /// get sort columns that can be used for partitioning, empty if absent
+ fn sort_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
+ self.partition_by()
+ .iter()
+ .map(|expr| {
+ PhysicalSortExpr {
+ expr: expr.clone(),
+ options: SortOptions::default(),
+ }
+ .evaluate_to_sort_column(batch)
+ })
+ .chain(
+ self.order_by()
+ .iter()
+ .map(|e| e.evaluate_to_sort_column(batch)),
+ )
+ .collect()
+ }
}
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
@@ -528,58 +580,6 @@ pub trait Accumulator: Send + Sync + Debug {
fn evaluate(&self) -> Result<ScalarValue>;
}
-/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple
-/// rows and generically accumulates values.
-///
-/// An accumulator knows how to:
-/// * update its state from inputs via `update`
-/// * convert its internal state to a vector of scalar values
-/// * update its state from multiple accumulators' states via `merge`
-/// * compute the final value from its internal state via `evaluate`
-pub trait WindowAccumulator: Send + Sync + Debug {
- /// scans the accumulator's state from a vector of scalars, similar to Accumulator it also
- /// optionally generates values.
- fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>>;
-
- /// scans the accumulator's state from a vector of arrays.
- fn scan_batch(
- &mut self,
- num_rows: usize,
- values: &[ArrayRef],
- ) -> Result<Option<ArrayRef>> {
- if values.is_empty() {
- return Ok(None);
- };
- // transpose columnar to row based so that we can apply window
- let result = (0..num_rows)
- .map(|index| {
- let v = values
- .iter()
- .map(|array| ScalarValue::try_from_array(array, index))
- .collect::<Result<Vec<_>>>()?;
- self.scan(&v)
- })
- .collect::<Result<Vec<Option<ScalarValue>>>>()?
- .into_iter()
- .collect::<Option<Vec<ScalarValue>>>();
-
- Ok(match result {
- Some(arr) if num_rows == arr.len() => Some(ScalarValue::iter_to_array(arr)?),
- None => None,
- Some(arr) => {
- return Err(DataFusionError::Internal(format!(
- "expect scan batch to return {:?} rows, but got {:?}",
- num_rows,
- arr.len()
- )))
- }
- })
- }
-
- /// returns its value based on its current state.
- fn evaluate(&self) -> Result<Option<ScalarValue>>;
-}
-
pub mod aggregates;
pub mod array_expressions;
pub mod coalesce_batches;
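The evaluate_sort_partition_points method above delegates to Arrow's lexicographical_partition_ranges kernel, which, for columns that are already sorted, yields one Range<usize> per run of equal sort keys ("peers"). A self-contained sketch of the same idea on a single sorted column (partition_ranges is an illustrative stand-in, not the Arrow kernel):

    use std::ops::Range;

    fn partition_ranges<T: PartialEq>(sorted: &[T]) -> Vec<Range<usize>> {
        let mut ranges = Vec::new();
        let mut start = 0;
        for i in 1..sorted.len() {
            if sorted[i] != sorted[i - 1] {
                ranges.push(start..i); // close the current peer group
                start = i;
            }
        }
        if !sorted.is_empty() {
            ranges.push(start..sorted.len()); // the final peer group
        }
        ranges
    }

    fn main() {
        // an ORDER BY key with three peer groups
        assert_eq!(partition_ranges(&[1, 1, 2, 3, 3]), vec![0..2, 2..3, 3..5]);
    }

When sort_columns is empty the method instead returns a single range covering the whole batch, which is why a window expression without PARTITION BY or ORDER BY aggregates over the entire partition at once.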
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 31b3749..1121c28 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -143,7 +143,12 @@ impl DefaultPhysicalPlanner {
LogicalPlan::Window {
input, window_expr, ..
} => {
- // Initially need to perform the aggregate and then merge the partitions
+ if window_expr.is_empty() {
+ return Err(DataFusionError::Internal(
+ "Impossibly got empty window expression".to_owned(),
+ ));
+ }
+
let input_exec = self.create_initial_plan(input, ctx_state)?;
let input_schema = input_exec.schema();
@@ -364,7 +369,7 @@ impl DefaultPhysicalPlanner {
let left_expr = keys.iter().map(|x| col(&x.0)).collect();
let right_expr = keys.iter().map(|x| col(&x.1)).collect();
- // Use hash partition by defualt to parallelize hash joins
+ // Use hash partition by default to parallelize hash joins
Ok(Arc::new(HashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
left,
@@ -776,12 +781,6 @@ impl DefaultPhysicalPlanner {
.to_owned(),
));
}
- if !order_by.is_empty() {
- return Err(DataFusionError::NotImplemented(
- "window expression with non-empty order by clause is not yet supported"
- .to_owned(),
- ));
- }
if window_frame.is_some() {
return Err(DataFusionError::NotImplemented(
"window expression with window frame definition is not yet supported"
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index e6afcaa..4f56aa7 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -20,11 +20,12 @@
//!
//! see also https://www.postgresql.org/docs/current/functions-window.html
+use crate::arrow::array::ArrayRef;
use crate::arrow::datatypes::Field;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
aggregates, aggregates::AggregateFunction, functions::Signature,
- type_coercion::data_types, PhysicalExpr, WindowAccumulator,
+ type_coercion::data_types, PhysicalExpr,
};
use arrow::datatypes::DataType;
use std::any::Any;
@@ -207,7 +208,10 @@ pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
}
}
-/// A window expression that is a built-in window function
+/// A window expression that is a built-in window function.
+///
+/// Note that unlike aggregation based window functions, built-in window functions normally ignore
+/// window frame spec, with the exception of first_value, last_value, and nth_value.
pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
/// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be
/// downcast to a specific implementation.
@@ -226,10 +230,8 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
"BuiltInWindowFunctionExpr: default name"
}
- /// the accumulator used to accumulate values from the expressions.
- /// the accumulator expects the same number of arguments as `expressions` and must
- /// return states with the same description as `state_fields`
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
+ /// Evaluate the built-in window function against the number of rows and the arguments
+ fn evaluate(&self, num_rows: usize, values: &[ArrayRef]) -> Result<ArrayRef>;
}
#[cfg(test)]
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index f95dd44..e557097 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -18,8 +18,7 @@
//! Execution plan for window functions
use crate::error::{DataFusionError, Result};
-
-use crate::logical_plan::window_frames::WindowFrame;
+use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::physical_plan::{
aggregates, common,
expressions::{Literal, NthValue, PhysicalSortExpr, RowNumber},
@@ -28,9 +27,9 @@ use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr,
window_functions::{BuiltInWindowFunction, WindowFunction},
Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
- RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr,
+ RecordBatchStream, SendableRecordBatchStream, WindowExpr,
};
-use crate::scalar::ScalarValue;
+use arrow::compute::concat;
use arrow::{
array::ArrayRef,
datatypes::{Field, Schema, SchemaRef},
@@ -43,6 +42,7 @@ use futures::Future;
use pin_project_lite::pin_project;
use std::any::Any;
use std::convert::TryInto;
+use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -65,12 +65,9 @@ pub fn create_window_expr(
fun: &WindowFunction,
name: String,
args: &[Arc<dyn PhysicalExpr>],
- // https://github.com/apache/arrow-datafusion/issues/299
- _partition_by: &[Arc<dyn PhysicalExpr>],
- // https://github.com/apache/arrow-datafusion/issues/360
- _order_by: &[PhysicalSortExpr],
- // https://github.com/apache/arrow-datafusion/issues/361
- _window_frame: Option<WindowFrame>,
+ partition_by: &[Arc<dyn PhysicalExpr>],
+ order_by: &[PhysicalSortExpr],
+ window_frame: Option<WindowFrame>,
input_schema: &Schema,
) -> Result<Arc<dyn WindowExpr>> {
Ok(match fun {
@@ -82,9 +79,15 @@ pub fn create_window_expr(
input_schema,
name,
)?,
+ partition_by: partition_by.to_vec(),
+ order_by: order_by.to_vec(),
+ window_frame,
}),
WindowFunction::BuiltInWindowFunction(fun) => Arc::new(BuiltInWindowExpr {
window: create_built_in_window_expr(fun, args, input_schema, name)?,
+ partition_by: partition_by.to_vec(),
+ order_by: order_by.to_vec(),
+ window_frame,
}),
})
}
@@ -136,6 +139,9 @@ fn create_built_in_window_expr(
#[derive(Debug)]
pub struct BuiltInWindowExpr {
window: Arc<dyn BuiltInWindowFunctionExpr>,
+ partition_by: Vec<Arc<dyn PhysicalExpr>>,
+ order_by: Vec<PhysicalSortExpr>,
+ window_frame: Option<WindowFrame>,
}
impl WindowExpr for BuiltInWindowExpr {
@@ -156,8 +162,20 @@ impl WindowExpr for BuiltInWindowExpr {
self.window.expressions()
}
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
- self.window.create_accumulator()
+ fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
+ &self.partition_by
+ }
+
+ fn order_by(&self) -> &[PhysicalSortExpr] {
+ &self.order_by
+ }
+
+ fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ // FIXME, for now we assume all the rows belong to the same partition, which will not be the
+ // case when partition_by is supported, in which case we'll parallelize the calls.
+ // See https://github.com/apache/arrow-datafusion/issues/299
+ let values = self.evaluate_args(batch)?;
+ self.window.evaluate(batch.num_rows(), &values)
}
}
@@ -165,22 +183,51 @@ impl WindowExpr for BuiltInWindowExpr {
#[derive(Debug)]
pub struct AggregateWindowExpr {
aggregate: Arc<dyn AggregateExpr>,
+ partition_by: Vec<Arc<dyn PhysicalExpr>>,
+ order_by: Vec<PhysicalSortExpr>,
+ window_frame: Option<WindowFrame>,
}
-#[derive(Debug)]
-struct AggregateWindowAccumulator {
- accumulator: Box<dyn Accumulator>,
-}
+impl AggregateWindowExpr {
+ /// the aggregate window function operates based on the window frame, and the
+ /// default mode is "range".
+ fn evaluation_mode(&self) -> WindowFrameUnits {
+ self.window_frame.unwrap_or_default().units
+ }
-impl WindowAccumulator for AggregateWindowAccumulator {
- fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
- self.accumulator.update(values)?;
- Ok(None)
+ /// create a new accumulator based on the underlying aggregation function
+ fn create_accumulator(&self) -> Result<AggregateWindowAccumulator> {
+ let accumulator = self.aggregate.create_accumulator()?;
+ Ok(AggregateWindowAccumulator { accumulator })
}
- /// returns its value based on its current state.
- fn evaluate(&self) -> Result<Option<ScalarValue>> {
- Ok(Some(self.accumulator.evaluate()?))
+ /// peer-based evaluation relies on the batch being pre-sorted on the sort
+ /// columns; for each partition point we evaluate the peer group (e.g. SUM or
+ /// MAX gives the same result for all peers) and concatenate the results.
+ fn peer_based_evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ let sort_partition_points = self.evaluate_sort_partition_points(batch)?;
+ let mut window_accumulators = self.create_accumulator()?;
+ let values = self.evaluate_args(batch)?;
+ let results = sort_partition_points
+ .iter()
+ .map(|peer_range| window_accumulators.scan_peers(&values, peer_range))
+ .collect::<Result<Vec<_>>>()?;
+ let results = results.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
+ concat(&results).map_err(DataFusionError::ArrowError)
+ }
+
+ fn group_based_evaluate(&self, _batch: &RecordBatch) -> Result<ArrayRef> {
+ Err(DataFusionError::NotImplemented(format!(
+ "Group based evaluation for {} is not yet implemented",
+ self.name()
+ )))
+ }
+
+ fn row_based_evaluate(&self, _batch: &RecordBatch) -> Result<ArrayRef> {
+ Err(DataFusionError::NotImplemented(format!(
+ "Row based evaluation for {} is not yet implemented",
+ self.name()
+ )))
}
}
@@ -202,9 +249,55 @@ impl WindowExpr for AggregateWindowExpr {
self.aggregate.expressions()
}
- fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
- let accumulator = self.aggregate.create_accumulator()?;
- Ok(Box::new(AggregateWindowAccumulator { accumulator }))
+ fn partition_by(&self) -> &[Arc<dyn PhysicalExpr>] {
+ &self.partition_by
+ }
+
+ fn order_by(&self) -> &[PhysicalSortExpr] {
+ &self.order_by
+ }
+
+ /// evaluate the window function values against the batch
+ fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+ // FIXME, for now we assume all the rows belong to the same partition, which will not be the
+ // case when partition_by is supported, in which case we'll parallelize the calls.
+ // See https://github.com/apache/arrow-datafusion/issues/299
+ match self.evaluation_mode() {
+ WindowFrameUnits::Range => self.peer_based_evaluate(batch),
+ WindowFrameUnits::Rows => self.row_based_evaluate(batch),
+ WindowFrameUnits::Groups => self.group_based_evaluate(batch),
+ }
+ }
+}
+
+/// Aggregate window accumulator utilizes the accumulator from aggregation and does
+/// a cumulative evaluation across the arguments based on peer equivalences.
+#[derive(Debug)]
+struct AggregateWindowAccumulator {
+ accumulator: Box<dyn Accumulator>,
+}
+
+impl AggregateWindowAccumulator {
+ /// scan one peer group of values (as arguments to the window function) given by
+ /// the value_range and return an evaluation result with the same number of rows.
+ fn scan_peers(
+ &mut self,
+ values: &[ArrayRef],
+ value_range: &Range<usize>,
+ ) -> Result<ArrayRef> {
+ if value_range.is_empty() {
+ return Err(DataFusionError::Internal(
+ "Value range cannot be empty".to_owned(),
+ ));
+ }
+ let len = value_range.end - value_range.start;
+ let values = values
+ .iter()
+ .map(|v| v.slice(value_range.start, len))
+ .collect::<Vec<_>>();
+ self.accumulator.update_batch(&values)?;
+ let value = self.accumulator.evaluate()?;
+ Ok(value.to_array_of_size(len))
}
}
@@ -329,106 +422,17 @@ pin_project! {
}
}
-type WindowAccumulatorItem = Box<dyn WindowAccumulator>;
-
-fn window_expressions(
- window_expr: &[Arc<dyn WindowExpr>],
-) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
- Ok(window_expr
- .iter()
- .map(|expr| expr.expressions())
- .collect::<Vec<_>>())
-}
-
-fn window_aggregate_batch(
- batch: &RecordBatch,
- window_accumulators: &mut [WindowAccumulatorItem],
- expressions: &[Vec<Arc<dyn PhysicalExpr>>],
-) -> Result<Vec<Option<ArrayRef>>> {
- window_accumulators
- .iter_mut()
- .zip(expressions)
- .map(|(window_acc, expr)| {
- let values = &expr
- .iter()
- .map(|e| e.evaluate(batch))
- .map(|r| r.map(|v| v.into_array(batch.num_rows())))
- .collect::<Result<Vec<_>>>()?;
- window_acc.scan_batch(batch.num_rows(), values)
- })
- .collect::<Result<Vec<_>>>()
-}
-
-/// returns a vector of ArrayRefs, where each entry corresponds to one window expr
-fn finalize_window_aggregation(
- window_accumulators: &[WindowAccumulatorItem],
-) -> Result<Vec<Option<ScalarValue>>> {
- window_accumulators
- .iter()
- .map(|window_accumulator| window_accumulator.evaluate())
- .collect::<Result<Vec<_>>>()
-}
-
-fn create_window_accumulators(
- window_expr: &[Arc<dyn WindowExpr>],
-) -> Result<Vec<WindowAccumulatorItem>> {
- window_expr
- .iter()
- .map(|expr| expr.create_accumulator())
- .collect::<Result<Vec<_>>>()
-}
-
/// Compute the window aggregate columns
-///
-/// 1. get a list of window accumulators
-/// 2. evaluate the args
-/// 3. scan args with window functions
-/// 4. concat with final aggregations
-///
-/// FIXME so far this fn does not support:
-/// 1. partition by
-/// 2. order by
-/// 3. window frame
-///
-/// which will require further work:
-/// 1. inter-partition order by using vec partition-point (https://github.com/apache/arrow-datafusion/issues/360)
-/// 2. inter-partition parallelism using one-shot channel (https://github.com/apache/arrow-datafusion/issues/299)
-/// 3. convert aggregation based window functions to be self-contain so that: (https://github.com/apache/arrow-datafusion/issues/361)
-/// a. some can be grow-only window-accumulating
-/// b. some can be grow-and-shrink window-accumulating
-/// c. some can be based on segment tree
fn compute_window_aggregates(
window_expr: Vec<Arc<dyn WindowExpr>>,
batch: &RecordBatch,
) -> Result<Vec<ArrayRef>> {
- let mut window_accumulators = create_window_accumulators(&window_expr)?;
- let expressions = Arc::new(window_expressions(&window_expr)?);
- let num_rows = batch.num_rows();
- let window_aggregates =
- window_aggregate_batch(batch, &mut window_accumulators, &expressions)?;
- let final_aggregates = finalize_window_aggregation(&window_accumulators)?;
-
- // both must equal to window_expr.len()
- if window_aggregates.len() != final_aggregates.len() {
- return Err(DataFusionError::Internal(
- "Impossibly got len mismatch".to_owned(),
- ));
- }
-
- window_aggregates
+ // FIXME, for now we assume all the rows belong to the same partition, which will not be the
+ // case when partition_by is supported, in which case we'll parallelize the calls.
+ // See https://github.com/apache/arrow-datafusion/issues/299
+ window_expr
.iter()
- .zip(final_aggregates)
- .map(|(wa, fa)| {
- Ok(match (wa, fa) {
- (None, Some(fa)) => fa.to_array_of_size(num_rows),
- (Some(wa), None) if wa.len() == num_rows => wa.clone(),
- _ => {
- return Err(DataFusionError::Execution(
- "Invalid window function behavior".to_owned(),
- ))
- }
- })
- })
+ .map(|window_expr| window_expr.evaluate(batch))
.collect()
}
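The peer-based path in AggregateWindowExpr keeps a single accumulator alive across the whole partition: scan_peers slices the argument arrays to one peer range at a time, updates the accumulator, and repeats the running value across that range, after which the per-range results are concatenated. A sketch of that shape for a running SUM, with plain i64 slices in place of Arrow arrays (running_sum_by_peers is an illustrative name):

    use std::ops::Range;

    fn running_sum_by_peers(values: &[i64], peers: &[Range<usize>]) -> Vec<i64> {
        let mut acc: i64 = 0; // lives across peer groups, like the accumulator above
        let mut out = Vec::with_capacity(values.len());
        for range in peers {
            // update over this peer group's slice, as update_batch does above
            acc += values[range.clone()].iter().sum::<i64>();
            // evaluate once, then broadcast to all rows of the peer group
            out.extend(std::iter::repeat(acc).take(range.len()));
        }
        out // concat of the per-peer results
    }

    fn main() {
        // sort key [1, 1, 2, 3, 3] -> peer ranges [0..2, 2..3, 3..5]
        let peers = vec![0..2, 2..3, 3..5];
        // SUM(v) OVER (ORDER BY k): tied keys share one running value
        assert_eq!(
            running_sum_by_peers(&[1, 2, 3, 4, 5], &peers),
            vec![3, 3, 6, 15, 15]
        );
    }

This is the RANGE-mode behavior dispatched by evaluation_mode above; the ROWS and GROUPS modes are left as NotImplemented in this commit.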
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index ac7deee..933bb8c 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -771,7 +771,7 @@ impl ScalarValue {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
// look up the index in the values dictionary
- let keys_col = dict_array.keys_array();
+ let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
"Can not convert index to usize in dictionary of type creating group by value {:?}",
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index e860bd7..4c1d861 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -703,7 +703,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut plan = input;
let mut groups = group_window_expr_by_sort_keys(&window_exprs)?;
// sort by sort_key len descending, so that more deeply sorted plans gets nested further
- // down as children; to further minic the behavior of PostgreSQL, we want stable sort
+ // down as children; to further mimic the behavior of PostgreSQL, we want stable sort
// and a reverse so that tying sort keys are reversed in order; note that by this rule
// if there's an empty over, it'll be at the top level
groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len()));
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index d9d7764..21da793 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -802,25 +802,142 @@ async fn csv_query_window_with_empty_over() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;
let sql = "select \
- c2, \
- sum(c3) over (), \
- avg(c3) over (), \
- count(c3) over (), \
- max(c3) over (), \
- min(c3) over (), \
- first_value(c3) over (), \
- last_value(c3) over (), \
- nth_value(c3, 2) over ()
+ c9, \
+ count(c5) over (), \
+ max(c5) over (), \
+ min(c5) over (), \
+ first_value(c5) over (), \
+ last_value(c5) over (), \
+ nth_value(c5, 2) over () \
from aggregate_test_100 \
- order by c2
+ order by c9 \
limit 5";
let actual = execute(&mut ctx, sql).await;
let expected = vec![
- vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
- vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
- vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
- vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
- vec!["1", "781", "7.81", "100", "125", "-117", "1", "30", "-40"],
+ vec![
+ "28774375",
+ "100",
+ "2143473091",
+ "-2141999138",
+ "2033001162",
+ "61035129",
+ "706441268",
+ ],
+ vec![
+ "63044568",
+ "100",
+ "2143473091",
+ "-2141999138",
+ "2033001162",
+ "61035129",
+ "706441268",
+ ],
+ vec![
+ "141047417",
+ "100",
+ "2143473091",
+ "-2141999138",
+ "2033001162",
+ "61035129",
+ "706441268",
+ ],
+ vec![
+ "141680161",
+ "100",
+ "2143473091",
+ "-2141999138",
+ "2033001162",
+ "61035129",
+ "706441268",
+ ],
+ vec![
+ "145294611",
+ "100",
+ "2143473091",
+ "-2141999138",
+ "2033001162",
+ "61035129",
+ "706441268",
+ ],
+ ];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_window_with_order_by() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx)?;
+ let sql = "select \
+ c9, \
+ sum(c5) over (order by c9), \
+ avg(c5) over (order by c9), \
+ count(c5) over (order by c9), \
+ max(c5) over (order by c9), \
+ min(c5) over (order by c9), \
+ first_value(c5) over (order by c9), \
+ last_value(c5) over (order by c9), \
+ nth_value(c5, 2) over (order by c9) \
+ from aggregate_test_100 \
+ order by c9 \
+ limit 5";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![
+ vec![
+ "28774375",
+ "61035129",
+ "61035129",
+ "1",
+ "61035129",
+ "61035129",
+ "61035129",
+ "2025611582",
+ "-108973366",
+ ],
+ vec![
+ "63044568",
+ "-47938237",
+ "-23969118.5",
+ "2",
+ "61035129",
+ "-108973366",
+ "61035129",
+ "2025611582",
+ "-108973366",
+ ],
+ vec![
+ "141047417",
+ "575165281",
+ "191721760.33333334",
+ "3",
+ "623103518",
+ "-108973366",
+ "61035129",
+ "2025611582",
+ "-108973366",
+ ],
+ vec![
+ "141680161",
+ "-1352462829",
+ "-338115707.25",
+ "4",
+ "623103518",
+ "-1927628110",
+ "61035129",
+ "2025611582",
+ "-108973366",
+ ],
+ vec![
+ "145294611",
+ "-3251637940",
+ "-650327588",
+ "5",
+ "623103518",
+ "-1927628110",
+ "61035129",
+ "2025611582",
+ "-108973366",
+ ],
];
assert_eq!(expected, actual);
Ok(())
diff --git a/integration-tests/sqls/simple_window_ordered_aggregation.sql b/integration-tests/sqls/simple_window_ordered_aggregation.sql
new file mode 100644
index 0000000..d9f467b
--- /dev/null
+++ b/integration-tests/sqls/simple_window_ordered_aggregation.sql
@@ -0,0 +1,26 @@
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+
+-- http://www.apache.org/licenses/LICENSE-2.0
+
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+SELECT
+ c9,
+ row_number() OVER (ORDER BY c2, c9) AS row_number,
+ count(c3) OVER (ORDER BY c9) AS count_c3,
+ avg(c3) OVER (ORDER BY c2) AS avg_c3_by_c2,
+ sum(c3) OVER (ORDER BY c2) AS sum_c3_by_c2,
+ max(c3) OVER (ORDER BY c2) AS max_c3_by_c2,
+ min(c3) OVER (ORDER BY c2) AS min_c3_by_c2
+FROM test
+ORDER BY row_number;
diff --git a/integration-tests/test_psql_parity.py b/integration-tests/test_psql_parity.py
index 51861c5..4e0878c 100644
--- a/integration-tests/test_psql_parity.py
+++ b/integration-tests/test_psql_parity.py
@@ -74,7 +74,7 @@ class PsqlParityTest(unittest.TestCase):
def test_parity(self):
root = Path(os.path.dirname(__file__)) / "sqls"
files = set(root.glob("*.sql"))
- self.assertEqual(len(files), 6, msg="tests are missed")
+ self.assertEqual(len(files), 7, msg="tests are missed")
for fname in files:
with self.subTest(fname=fname):
datafusion_output = pd.read_csv(