You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/05/26 20:03:26 UTC
[arrow-datafusion] branch master updated: add window expression
stream, delegated window aggregation to aggregate functions,
and implement `row_number` (#375)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4b1e9e6 add window expression stream, delegated window aggregation to aggregate functions, and implement `row_number` (#375)
4b1e9e6 is described below
commit 4b1e9e6fae0e200debda215f6ad78c654c37c1a8
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Thu May 27 04:03:20 2021 +0800
add window expression stream, delegated window aggregation to aggregate functions, and implement `row_number` (#375)
* Squashed commit of the following:
commit 7fb3640e733bfbbdbf18d58000896f378ba9644c
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 16:38:25 2021 +0800
row number done
commit 17239267cd2fbcbb676d5731beeffd0321bbd3ba
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 16:05:50 2021 +0800
add row number
commit bf5b8a56f6f33d8eedf3e3009e7fcdb3c388ea5b
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 15:04:49 2021 +0800
save
commit d2ce852ead5d8ae3d15962b4dd3062e24bce51de
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 14:53:05 2021 +0800
add streams
commit 0a861a76bde0bb43e5561f1cf1ef14fd64e0c08b
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 22:28:34 2021 +0800
save stream
commit a9121af7e2e9104d0e4b6ca3ef4f484aaf8baf42
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 22:01:51 2021 +0800
update unit test
commit 2af2a270262ff1bc759af39153d7cd681c32dc0a
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 14:25:12 2021 +0800
fix unit test
commit bb57c762b0a1fabc35e207e681bca2bfff7fcf01
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 14:23:34 2021 +0800
use upper case
commit 5d96e525f587fbfaf3e5e9762c9bb10315fcbc3a
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 14:16:16 2021 +0800
fix unit test
commit 1ecae8f6cbc6c1898ccf0b38b1e596b6c2e9bb46
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 12:27:26 2021 +0800
fix unit test
commit bc2271d58fd4a9a9cc96126f8abcd6e8f10272ca
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 10:04:29 2021 +0800
fix error
commit 880b94f6e27df61b4d3877366f71a51b9b2f5d5d
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 08:24:00 2021 +0800
fix unit test
commit 4e792e123a33fd0dcb5f701c679566b55589b0c0
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 08:05:17 2021 +0800
fix test
commit c36c04abf06c74d016597983bf3d3a2a5b5cbdd5
Author: Jiayu Liu <ji...@airbnb.com>
Date: Fri May 21 00:07:54 2021 +0800
add more tests
commit f5e64de7192a1916df78a4c2fbab7d471c906720
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 23:41:36 2021 +0800
update
commit a1eae864926a6acfeeebe995a12de4ad725ea869
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 23:36:15 2021 +0800
enrich unit test
commit 0d2a214131fe69e19e22144c68fbb992228db6b3
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 23:25:43 2021 +0800
adding filter by todo
commit 8b486d53b09ff1c7a6b9cf4687796ba1c13d6160
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 20 23:17:22 2021 +0800
adding more built-in functions
commit abf08cd137a80c1381af7de9ae2b3dab05cb4512
Author: Jiayu Liu <Ji...@users.noreply.github.com>
Date: Thu May 20 22:36:27 2021 +0800
Update datafusion/src/physical_plan/window_functions.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
commit 0cbca53dac642233520f7d32289b1dfad77b882e
Author: Jiayu Liu <Ji...@users.noreply.github.com>
Date: Thu May 20 22:34:57 2021 +0800
Update datafusion/src/physical_plan/window_functions.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
commit 831c069f02236a953653b8f1ca25124e393ce20b
Author: Jiayu Liu <Ji...@users.noreply.github.com>
Date: Thu May 20 22:34:04 2021 +0800
Update datafusion/src/logical_plan/builder.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
commit f70c739fd40e30c4b476253e58b24b9297b42859
Author: Jiayu Liu <Ji...@users.noreply.github.com>
Date: Thu May 20 22:33:04 2021 +0800
Update datafusion/src/logical_plan/builder.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
commit 3ee87aa3477c160f17a86628d71a353e03d736b3
Author: Jiayu Liu <ji...@airbnb.com>
Date: Wed May 19 22:55:08 2021 +0800
fix unit test
commit 5c4d92dc9f570ba6919d84cb8ac70a736d73f40f
Author: Jiayu Liu <ji...@airbnb.com>
Date: Wed May 19 22:48:26 2021 +0800
fix clippy
commit a0b7526c413abbdd4aadab4af8ca9ad8f323f03b
Author: Jiayu Liu <ji...@airbnb.com>
Date: Wed May 19 22:46:38 2021 +0800
fix unused imports
commit 1d3b076acc1c0f248a19c6149c0634e63a5b836e
Author: Jiayu Liu <ji...@airbnb.com>
Date: Thu May 13 18:51:14 2021 +0800
add window expr
* fix unit test
---
datafusion/src/execution/context.rs | 29 ++
datafusion/src/physical_plan/expressions/mod.rs | 2 +
.../src/physical_plan/expressions/row_number.rs | 174 ++++++++++
datafusion/src/physical_plan/hash_aggregate.rs | 7 +-
datafusion/src/physical_plan/mod.rs | 81 ++++-
datafusion/src/physical_plan/planner.rs | 4 +-
datafusion/src/physical_plan/sort.rs | 1 +
datafusion/src/physical_plan/window_functions.rs | 107 ++++--
datafusion/src/physical_plan/windows.rs | 365 ++++++++++++++++++++-
datafusion/tests/sql.rs | 39 ++-
parquet-testing | 2 +-
11 files changed, 736 insertions(+), 75 deletions(-)
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 272e75a..cfd3b71 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1269,6 +1269,35 @@ mod tests {
}
#[tokio::test]
+ async fn window() -> Result<()> {
+ let results = execute(
+ "SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
+ 4,
+ )
+ .await?;
+ // the result arrives in one batch; although e.g. two batches would not change
+ // the result semantics, asserting len == 1 upfront keeps surprises
+ // at bay
+ assert_eq!(results.len(), 1);
+
+ let expected = vec![
+ "+----+----+---------+-----------+---------+---------+---------+",
+ "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
+ "+----+----+---------+-----------+---------+---------+---------+",
+ "| 0 | 1 | 220 | 40 | 10 | 1 | 5.5 |",
+ "| 0 | 2 | 220 | 40 | 10 | 1 | 5.5 |",
+ "| 0 | 3 | 220 | 40 | 10 | 1 | 5.5 |",
+ "| 0 | 4 | 220 | 40 | 10 | 1 | 5.5 |",
+ "| 0 | 5 | 220 | 40 | 10 | 1 | 5.5 |",
+ "+----+----+---------+-----------+---------+---------+---------+",
+ ];
+
+ // adding window functions must not disturb the ORDER BY c1, c2 output ordering
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
assert_eq!(results.len(), 1);
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index 4d57c39..803870f 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -41,6 +41,7 @@ mod min_max;
mod negative;
mod not;
mod nullif;
+mod row_number;
mod sum;
mod try_cast;
@@ -58,6 +59,7 @@ pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
+pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
pub use try_cast::{try_cast, TryCastExpr};
/// returns the name of the state
diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs
new file mode 100644
index 0000000..f399995
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/row_number.rs
@@ -0,0 +1,174 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution
+
+use crate::error::Result;
+use crate::physical_plan::{
+ window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
+};
+use crate::scalar::ScalarValue;
+use arrow::array::{ArrayRef, UInt64Array};
+use arrow::datatypes::{DataType, Field};
+use std::any::Any;
+use std::sync::Arc;
+
+/// `ROW_NUMBER()` window function expression: numbers the scanned rows
+/// consecutively, starting at 1.
+#[derive(Debug)]
+pub struct RowNumber {
+ /// name of the output column
+ name: String,
+}
+
+impl RowNumber {
+ /// Create a new ROW_NUMBER function whose output field is called `name`
+ pub fn new(name: String) -> Self {
+ Self { name }
+ }
+}
+
+impl BuiltInWindowFunctionExpr for RowNumber {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ /// The output is a non-nullable `UInt64` field named after this expression
+ fn field(&self) -> Result<Field> {
+ let nullable = false;
+ let data_type = DataType::UInt64;
+ Ok(Field::new(self.name(), data_type, nullable))
+ }
+
+ /// `ROW_NUMBER()` takes no input expressions
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![]
+ }
+
+ fn name(&self) -> &str {
+ self.name.as_str()
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+ Ok(Box::new(RowNumberAccumulator::new()))
+ }
+}
+
+/// Accumulator that hands out consecutive row numbers
+#[derive(Debug)]
+struct RowNumberAccumulator {
+ /// the number that will be assigned to the next scanned row
+ row_number: u64,
+}
+
+impl RowNumberAccumulator {
+ /// new row_number accumulator
+ pub fn new() -> Self {
+ // row number is 1 based
+ Self { row_number: 1 }
+ }
+}
+
+impl WindowAccumulator for RowNumberAccumulator {
+ /// returns the current row number and advances the counter; input values are ignored
+ fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
+ let result = Some(ScalarValue::UInt64(Some(self.row_number)));
+ self.row_number += 1;
+ Ok(result)
+ }
+
+ /// emits `num_rows` consecutive row numbers starting at the current counter and
+ /// advances the counter past them; input arrays are ignored
+ fn scan_batch(
+ &mut self,
+ num_rows: usize,
+ _values: &[ArrayRef],
+ ) -> Result<Option<ArrayRef>> {
+ let new_row_number = self.row_number + (num_rows as u64);
+ // TODO: probably would be nice to have a (optimized) kernel for this at some point to
+ // generate an array like this.
+ let result = UInt64Array::from_iter_values(self.row_number..new_row_number);
+ self.row_number = new_row_number;
+ Ok(Some(Arc::new(result)))
+ }
+
+ /// row numbers are only produced per row, so there is no final value
+ fn evaluate(&self) -> Result<Option<ScalarValue>> {
+ Ok(None)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::error::Result;
+ use arrow::record_batch::RecordBatch;
+ use arrow::{array::*, datatypes::*};
+
+ /// scans the 8-row column `arr` with a fresh ROW_NUMBER accumulator and checks
+ /// that the rows are numbered 1..=8 and that no final value is produced
+ fn assert_row_numbers_1_to_8(arr: ArrayRef, nullable: bool) -> Result<()> {
+ let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, nullable)]);
+ let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
+
+ let row_number = Arc::new(RowNumber::new("row_number".to_owned()));
+
+ let mut acc = row_number.create_accumulator()?;
+ let expr = row_number.expressions();
+ let values = expr
+ .iter()
+ .map(|e| e.evaluate(&batch))
+ .map(|r| r.map(|v| v.into_array(batch.num_rows())))
+ .collect::<Result<Vec<_>>>()?;
+
+ let result = acc
+ .scan_batch(batch.num_rows(), &values)?
+ .expect("row_number must emit one value per row");
+ let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
+ assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result.values());
+
+ assert!(acc.evaluate()?.is_none());
+ Ok(())
+ }
+
+ #[test]
+ fn row_number_all_null() -> Result<()> {
+ let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
+ None, None, None, None, None, None, None, None,
+ ]));
+ // the column holds only nulls, so it must be declared nullable
+ assert_row_numbers_1_to_8(arr, true)
+ }
+
+ #[test]
+ fn row_number_all_values() -> Result<()> {
+ let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
+ true, false, true, false, false, true, false, true,
+ ]));
+ assert_row_numbers_1_to_8(arr, false)
+ }
+}
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index c9d2686..5008f49 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -712,7 +712,7 @@ impl GroupedHashAggregateStream {
tx.send(result)
});
- GroupedHashAggregateStream {
+ Self {
schema,
output: rx,
finished: false,
@@ -825,7 +825,8 @@ fn aggregate_expressions(
}
pin_project! {
- struct HashAggregateStream {
+ /// stream struct for hash aggregation
+ pub struct HashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
@@ -878,7 +879,7 @@ impl HashAggregateStream {
tx.send(result)
});
- HashAggregateStream {
+ Self {
schema,
output: rx,
finished: false,
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index c053229..4f90a8c 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -17,22 +17,23 @@
//! Traits for physical query plan, supporting parallel execution for partitioned relations.
-use std::fmt::{self, Debug, Display};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::{any::Any, pin::Pin};
-
use crate::execution::context::ExecutionContextState;
use crate::logical_plan::LogicalPlan;
-use crate::{error::Result, scalar::ScalarValue};
+use crate::{
+ error::{DataFusionError, Result},
+ scalar::ScalarValue,
+};
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};
-
use async_trait::async_trait;
pub use display::DisplayFormatType;
use futures::stream::Stream;
+use std::fmt::{self, Debug, Display};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::{any::Any, pin::Pin};
use self::{display::DisplayableExecutionPlan, merge::MergeExec};
use hashbrown::HashMap;
@@ -457,10 +458,22 @@ pub trait WindowExpr: Send + Sync + Debug {
fn name(&self) -> &str {
"WindowExpr: default name"
}
+
+ /// creates the [`WindowAccumulator`] used to accumulate values from the expressions.
+ /// the accumulator receives one value per entry of `expressions`, in the same
+ /// order
+ fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
+
+ /// expressions that are passed to the WindowAccumulator.
+ /// Functions which take a single input argument, such as `sum`, return a single
+ /// [`PhysicalExpr`], others (e.g. `cov`) return many.
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
}
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
-/// generically accumulates values. An accumulator knows how to:
+/// generically accumulates values.
+///
+/// An accumulator knows how to:
/// * update its state from inputs via `update`
/// * convert its internal state to a vector of scalar values
/// * update its state from multiple accumulators' states via `merge`
@@ -509,6 +522,58 @@ pub trait Accumulator: Send + Sync + Debug {
fn evaluate(&self) -> Result<ScalarValue>;
}
+/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple
+/// rows and generically accumulates values.
+///
+/// A window accumulator knows how to:
+/// * update its state from one row of input values via `scan`, optionally emitting a value for that row
+/// * do the same for a whole batch of columnar inputs via `scan_batch`
+/// * compute the final value from its internal state via `evaluate`
+pub trait WindowAccumulator: Send + Sync + Debug {
+ /// updates the accumulator's state from one row of input values (one value per
+ /// input expression) and optionally returns a value for that row.
+ fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>>;
+
+ /// updates the accumulator's state from a batch of arrays, one array per input
+ /// expression, each holding `num_rows` values.
+ ///
+ /// The default implementation transposes the columns into rows and calls `scan`
+ /// once per row. It returns `Ok(None)` when there are no input arrays or when any
+ /// row produced no value; otherwise it returns an array with one value per row.
+ fn scan_batch(
+ &mut self,
+ num_rows: usize,
+ values: &[ArrayRef],
+ ) -> Result<Option<ArrayRef>> {
+ if values.is_empty() {
+ return Ok(None);
+ }
+ // transpose columnar to row based so that we can apply window
+ let result = (0..num_rows)
+ .map(|index| {
+ let v = values
+ .iter()
+ .map(|array| ScalarValue::try_from_array(array, index))
+ .collect::<Result<Vec<_>>>()?;
+ self.scan(&v)
+ })
+ .collect::<Result<Vec<Option<ScalarValue>>>>()?
+ .into_iter()
+ .collect::<Option<Vec<ScalarValue>>>();
+
+ Ok(match result {
+ Some(arr) if num_rows == arr.len() => Some(ScalarValue::iter_to_array(&arr)?),
+ None => None,
+ // defensive: one value is collected per row index, so lengths should always match
+ Some(arr) => {
+ return Err(DataFusionError::Internal(format!(
+ "expect scan batch to return {:?} rows, but got {:?}",
+ num_rows,
+ arr.len()
+ )))
+ }
+ })
+ }
+
+ /// returns its value based on its current state, or `None` if it has no final value.
+ fn evaluate(&self) -> Result<Option<ScalarValue>>;
+}
+
pub mod aggregates;
pub mod array_expressions;
pub mod coalesce_batches;
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index 018925d..7ddfaf8 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner {
// Initially need to perform the aggregate and then merge the partitions
let input_exec = self.create_initial_plan(input, ctx_state)?;
let input_schema = input_exec.schema();
- let physical_input_schema = input_exec.as_ref().schema();
+
let logical_input_schema = input.as_ref().schema();
+ let physical_input_schema = input_exec.as_ref().schema();
+
let window_expr = window_expr
.iter()
.map(|e| {
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index 7cd4d9d..c5b838c 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -250,6 +250,7 @@ fn sort_batches(
}
pin_project! {
+ /// stream for sort plan
struct SortStream {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 65d5373..e6afcaa 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -20,12 +20,15 @@
//!
//! see also https://www.postgresql.org/docs/current/functions-window.html
+use crate::arrow::datatypes::Field;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
aggregates, aggregates::AggregateFunction, functions::Signature,
- type_coercion::data_types,
+ type_coercion::data_types, PhysicalExpr, WindowAccumulator,
};
use arrow::datatypes::DataType;
+use std::any::Any;
+use std::sync::Arc;
use std::{fmt, str::FromStr};
/// WindowFunction
@@ -143,52 +146,92 @@ impl FromStr for BuiltInWindowFunction {
/// Returns the datatype of the window function
pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataType> {
+ match fun {
+ WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types),
+ WindowFunction::BuiltInWindowFunction(fun) => {
+ return_type_for_built_in(fun, arg_types)
+ }
+ }
+}
+
+/// Returns the datatype of the built-in window function
+pub(super) fn return_type_for_built_in(
+ fun: &BuiltInWindowFunction,
+ arg_types: &[DataType],
+) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
// verify that this is a valid set of data types for this function
- data_types(arg_types, &signature(fun))?;
+ data_types(arg_types, &signature_for_built_in(fun))?;
match fun {
- WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types),
- WindowFunction::BuiltInWindowFunction(fun) => match fun {
- BuiltInWindowFunction::RowNumber
- | BuiltInWindowFunction::Rank
- | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
- BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
- Ok(DataType::Float64)
- }
- BuiltInWindowFunction::Ntile => Ok(DataType::UInt32),
- BuiltInWindowFunction::Lag
- | BuiltInWindowFunction::Lead
- | BuiltInWindowFunction::FirstValue
- | BuiltInWindowFunction::LastValue
- | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()),
- },
+ BuiltInWindowFunction::RowNumber
+ | BuiltInWindowFunction::Rank
+ | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64),
+ BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => {
+ Ok(DataType::Float64)
+ }
+ BuiltInWindowFunction::Ntile => Ok(DataType::UInt32),
+ BuiltInWindowFunction::Lag
+ | BuiltInWindowFunction::Lead
+ | BuiltInWindowFunction::FirstValue
+ | BuiltInWindowFunction::LastValue
+ | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()),
}
}
/// the signatures supported by the function `fun`.
-fn signature(fun: &WindowFunction) -> Signature {
- // note: the physical expression must accept the type returned by this function or the execution panics.
+pub fn signature(fun: &WindowFunction) -> Signature {
match fun {
WindowFunction::AggregateFunction(fun) => aggregates::signature(fun),
- WindowFunction::BuiltInWindowFunction(fun) => match fun {
- BuiltInWindowFunction::RowNumber
- | BuiltInWindowFunction::Rank
- | BuiltInWindowFunction::DenseRank
- | BuiltInWindowFunction::PercentRank
- | BuiltInWindowFunction::CumeDist => Signature::Any(0),
- BuiltInWindowFunction::Lag
- | BuiltInWindowFunction::Lead
- | BuiltInWindowFunction::FirstValue
- | BuiltInWindowFunction::LastValue => Signature::Any(1),
- BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
- BuiltInWindowFunction::NthValue => Signature::Any(2),
- },
+ WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun),
+ }
+}
+
+/// the signatures supported by the built-in window function `fun`.
+pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature {
+ // note: the physical expression must accept the type returned by this function or the execution panics.
+ match fun {
+ BuiltInWindowFunction::RowNumber
+ | BuiltInWindowFunction::Rank
+ | BuiltInWindowFunction::DenseRank
+ | BuiltInWindowFunction::PercentRank
+ | BuiltInWindowFunction::CumeDist => Signature::Any(0),
+ BuiltInWindowFunction::Lag
+ | BuiltInWindowFunction::Lead
+ | BuiltInWindowFunction::FirstValue
+ | BuiltInWindowFunction::LastValue => Signature::Any(1),
+ BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]),
+ BuiltInWindowFunction::NthValue => Signature::Any(2),
}
}
+/// A window expression that is a built-in window function
+pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug {
+ /// Returns the window function expression as [`Any`](std::any::Any) so that it can be
+ /// downcast to a specific implementation.
+ fn as_any(&self) -> &dyn Any;
+
+ /// the field of the final result of this window function.
+ fn field(&self) -> Result<Field>;
+
+ /// expressions that are passed to the WindowAccumulator.
+ /// Functions with a single argument return a single expression, others return many;
+ /// functions without arguments (e.g. `row_number`) return none.
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
+
+ /// Human readable name such as `"RANK()"`. The default
+ /// implementation returns placeholder text.
+ fn name(&self) -> &str {
+ "BuiltInWindowFunctionExpr: default name"
+ }
+
+ /// creates the accumulator used to accumulate values from the expressions.
+ /// the accumulator receives one value per entry of `expressions`, in the same
+ /// order
+ fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index bdd25d6..8ced3ae 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -19,13 +19,30 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::{
- aggregates, window_functions::WindowFunction, AggregateExpr, Distribution,
- ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr,
+ aggregates,
+ expressions::RowNumber,
+ window_functions::BuiltInWindowFunctionExpr,
+ window_functions::{BuiltInWindowFunction, WindowFunction},
+ Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
+ RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr,
+};
+use crate::scalar::ScalarValue;
+use arrow::compute::concat;
+use arrow::{
+ array::{Array, ArrayRef},
+ datatypes::{Field, Schema, SchemaRef},
+ error::{ArrowError, Result as ArrowResult},
+ record_batch::RecordBatch,
};
-use arrow::datatypes::{Field, Schema, SchemaRef};
use async_trait::async_trait;
+use futures::stream::{Stream, StreamExt};
+use futures::Future;
+use pin_project_lite::pin_project;
use std::any::Any;
+use std::iter;
+use std::pin::Pin;
use std::sync::Arc;
+use std::task::{Context, Poll};
/// Window execution plan
#[derive(Debug)]
@@ -57,18 +74,55 @@ pub fn create_window_expr(
name,
)?,
})),
- WindowFunction::BuiltInWindowFunction(fun) => {
- Err(DataFusionError::NotImplemented(format!(
- "window function with {:?} not implemented",
- fun
- )))
- }
+ WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr {
+ window: create_built_in_window_expr(fun, args, input_schema, name)?,
+ })),
+ }
+}
+
+/// Creates the physical expression for the built-in window function `fun`.
+///
+/// Only `ROW_NUMBER` is implemented so far; every other built-in function returns
+/// `DataFusionError::NotImplemented`. `_args` and `_input_schema` are currently unused.
+fn create_built_in_window_expr(
+ fun: &BuiltInWindowFunction,
+ _args: &[Arc<dyn PhysicalExpr>],
+ _input_schema: &Schema,
+ name: String,
+) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
+ match fun {
+ BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))),
+ _ => Err(DataFusionError::NotImplemented(format!(
+ "Window function with {:?} not yet implemented",
+ fun
+ ))),
}
}
/// A window expr that takes the form of a built in window function
#[derive(Debug)]
-pub struct BuiltInWindowExpr {}
+pub struct BuiltInWindowExpr {
+ /// the built-in window function this expression delegates to
+ window: Arc<dyn BuiltInWindowFunctionExpr>,
+}
+
+/// Every `WindowExpr` method delegates to the wrapped built-in window function.
+impl WindowExpr for BuiltInWindowExpr {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ self.window.name()
+ }
+
+ fn field(&self) -> Result<Field> {
+ self.window.field()
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ self.window.expressions()
+ }
+
+ fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+ self.window.create_accumulator()
+ }
+}
/// A window expr that takes the form of an aggregate function
#[derive(Debug)]
@@ -76,6 +130,23 @@ pub struct AggregateWindowExpr {
aggregate: Arc<dyn AggregateExpr>,
}
+/// Adapts an aggregate [`Accumulator`] to the [`WindowAccumulator`] interface.
+///
+/// Scanning only feeds values into the aggregate and emits nothing per row; the
+/// aggregate's result is produced by `evaluate`.
+#[derive(Debug)]
+struct AggregateWindowAccumulator {
+ accumulator: Box<dyn Accumulator>,
+}
+
+impl WindowAccumulator for AggregateWindowAccumulator {
+ /// feeds one row of values into the wrapped aggregate; no per-row output
+ fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
+ self.accumulator.update(values).map(|_| None)
+ }
+
+ /// returns the wrapped aggregate's current value
+ fn evaluate(&self) -> Result<Option<ScalarValue>> {
+ self.accumulator.evaluate().map(Some)
+ }
+}
+
impl WindowExpr for AggregateWindowExpr {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
@@ -89,6 +160,15 @@ impl WindowExpr for AggregateWindowExpr {
fn field(&self) -> Result<Field> {
self.aggregate.field()
}
+
+ /// the input expressions of the wrapped aggregate
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ self.aggregate.expressions()
+ }
+
+ /// wraps a fresh aggregate accumulator in an [`AggregateWindowAccumulator`]
+ fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
+ let inner = self.aggregate.create_accumulator()?;
+ let window_accumulator: Box<dyn WindowAccumulator> =
+ Box::new(AggregateWindowAccumulator { accumulator: inner });
+ Ok(window_accumulator)
+ }
}
fn create_schema(
@@ -120,12 +200,17 @@ impl WindowAggExec {
})
}
+ /// The window expressions computed by this plan
+ pub fn window_expr(&self) -> &[Arc<dyn WindowExpr>] {
+ &self.window_expr
+ }
+
/// Input plan
pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
&self.input
}
- /// Get the input schema before any aggregates are applied
+ /// Get the input schema before any window functions are applied
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}
@@ -163,7 +248,7 @@ impl ExecutionPlan for WindowAggExec {
1 => Ok(Arc::new(WindowAggExec::try_new(
self.window_expr.clone(),
children[0].clone(),
- children[0].schema(),
+ self.input_schema.clone(),
)?)),
_ => Err(DataFusionError::Internal(
"WindowAggExec wrong number of children".to_owned(),
@@ -186,10 +271,258 @@ impl ExecutionPlan for WindowAggExec {
));
}
- // let input = self.input.execute(0).await?;
+ let input = self.input.execute(partition).await?;
+
+ let stream = Box::pin(WindowAggStream::new(
+ self.schema.clone(),
+ self.window_expr.clone(),
+ input,
+ ));
+ Ok(stream)
+ }
+}
+
+pin_project! {
+ /// stream for window aggregation plan
+ ///
+ /// The window computation runs in a task spawned by `WindowAggStream::new`;
+ /// this stream waits on a oneshot channel for that task's single result.
+ pub struct WindowAggStream {
+ // output schema reported by `RecordBatchStream::schema`
+ schema: SchemaRef,
+ #[pin]
+ // receives the one batch (or error) produced by the spawned task
+ output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
+ // set once `output` has resolved; later polls return `None`
+ finished: bool,
+ }
+}
+
+/// Boxed, dynamically dispatched window accumulator (one per window expression).
+type WindowAccumulatorItem = Box<dyn WindowAccumulator>;
+
+/// Gathers the argument expressions of every window expression, keeping the
+/// order of `window_expr`. Never fails; the `Result` is kept for callers.
+fn window_expressions(
+    window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
+    let mut all_args = Vec::with_capacity(window_expr.len());
+    for expr in window_expr {
+        all_args.push(expr.expressions());
+    }
+    Ok(all_args)
+}
+
+/// Feeds one input `batch` to every window accumulator.
+///
+/// For each accumulator (paired positionally with its argument expressions;
+/// `zip` stops at the shorter of the two slices) the expressions are evaluated
+/// against `batch`, materialized as arrays of `batch.num_rows()` rows, and
+/// passed to `scan_batch`.
+///
+/// Returns one entry per accumulator: whatever its `scan_batch` produced for
+/// this batch. Fails on the first expression or accumulator error.
+fn window_aggregate_batch(
+    batch: &RecordBatch,
+    window_accumulators: &mut [WindowAccumulatorItem],
+    expressions: &[Vec<Arc<dyn PhysicalExpr>>],
+) -> Result<Vec<Option<ArrayRef>>> {
+    let num_rows = batch.num_rows();
+    window_accumulators
+        .iter_mut()
+        .zip(expressions)
+        .map(|(window_acc, expr)| {
+            // evaluate this window function's arguments against the batch
+            let values = expr
+                .iter()
+                .map(|e| e.evaluate(batch).map(|v| v.into_array(num_rows)))
+                .collect::<Result<Vec<_>>>()?;
+            // update the window accumulator with the evaluated columns
+            window_acc.scan_batch(num_rows, &values)
+        })
+        .collect()
+}
+
+/// Evaluates every window accumulator after all input has been scanned.
+///
+/// Returns one `Option<ScalarValue>` per accumulator, in order, exactly as
+/// reported by its `evaluate`. Fails on the first accumulator error.
+fn finalize_window_aggregation(
+    window_accumulators: &[WindowAccumulatorItem],
+) -> Result<Vec<Option<ScalarValue>>> {
+    window_accumulators
+        .iter()
+        .map(|window_accumulator| window_accumulator.evaluate())
+        .collect()
+}
+
+/// Creates a fresh accumulator for each window expression, in order,
+/// stopping at the first creation error.
+fn create_window_accumulators(
+    window_expr: &[Arc<dyn WindowExpr>],
+) -> Result<Vec<WindowAccumulatorItem>> {
+    let mut accumulators = Vec::with_capacity(window_expr.len());
+    for expr in window_expr {
+        accumulators.push(expr.create_accumulator()?);
+    }
+    Ok(accumulators)
+}
+
+/// Runs every window function over the whole `input` stream and returns a
+/// single batch.
+///
+/// The output columns are the window columns (in `window_expr` order)
+/// followed by the input columns, each concatenated across all input batches.
+/// A window column is built in one of two ways:
+/// - from the arrays returned by `scan_batch`, concatenated, when
+///   `evaluate` then returns `None`;
+/// - from the scalar returned by `evaluate`, repeated for every input row,
+///   when `scan_batch` never returned an array.
+///
+/// Any other combination is an execution error. An input stream with no
+/// batches yields an empty batch with `schema`.
+async fn compute_window_aggregate(
+    schema: SchemaRef,
+    window_expr: Vec<Arc<dyn WindowExpr>>,
+    mut input: SendableRecordBatchStream,
+) -> ArrowResult<RecordBatch> {
+    let mut window_accumulators = create_window_accumulators(&window_expr)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    let expressions = window_expressions(&window_expr)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    // per window function: the arrays returned by `scan_batch`, one per batch
+    // TODO each element shall have some size hint
+    let mut partial_results: Vec<Vec<ArrayRef>> =
+        iter::repeat(vec![]).take(window_expr.len()).collect();
+
+    // input batches are kept so their columns can be appended to the output
+    let mut original_batches: Vec<RecordBatch> = vec![];
+    let mut total_num_rows = 0;
+
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        total_num_rows += batch.num_rows();
+
+        let batch_aggregated =
+            window_aggregate_batch(&batch, &mut window_accumulators, &expressions)
+                .map_err(DataFusionError::into_arrow_external_error)?;
+        partial_results
+            .iter_mut()
+            .zip(batch_aggregated)
+            .for_each(|(acc_for_window, window_batch)| {
+                if let Some(data) = window_batch {
+                    acc_for_window.push(data);
+                }
+            });
+
+        original_batches.push(batch);
+    }
+
+    // `concat` rejects an empty slice of arrays, so with no input batches the
+    // column concatenation below would fail; the correct result is empty.
+    if original_batches.is_empty() {
+        return Ok(RecordBatch::new_empty(schema));
+    }
+
+    let aggregated_mapped = finalize_window_aggregation(&window_accumulators)
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    let mut columns: Vec<ArrayRef> = partial_results
+        .iter()
+        .zip(aggregated_mapped)
+        .map(|(acc, agg)| {
+            Ok(match (acc, agg) {
+                // no per-batch arrays: repeat the final scalar for every row
+                (acc, Some(scalar_value)) if acc.is_empty() => {
+                    scalar_value.to_array_of_size(total_num_rows)
+                }
+                // per-batch arrays and no final scalar: stitch the arrays together
+                (acc, None) if !acc.is_empty() => {
+                    let vec_array: Vec<&dyn Array> =
+                        acc.iter().map(|arc| arc.as_ref()).collect();
+                    concat(&vec_array)?
+                }
+                _ => {
+                    return Err(DataFusionError::Execution(
+                        "Invalid window function behavior".to_owned(),
+                    ))
+                }
+            })
+        })
+        .collect::<Result<Vec<ArrayRef>>>()
+        .map_err(DataFusionError::into_arrow_external_error)?;
+
+    // append the input columns after the window columns
+    for i in 0..(schema.fields().len() - window_expr.len()) {
+        let col = concat(
+            &original_batches
+                .iter()
+                .map(|batch| batch.column(i).as_ref())
+                .collect::<Vec<_>>(),
+        )?;
+        columns.push(col);
+    }
+
+    RecordBatch::try_new(schema, columns)
+}
+
+impl WindowAggStream {
+ /// Create a new WindowAggStream
+ ///
+ /// Immediately spawns `compute_window_aggregate` over `input` with
+ /// `tokio::spawn` (so this must be called from within a Tokio runtime).
+ /// The spawned task's single result is sent through a oneshot channel
+ /// whose receiver is polled by this stream.
+ pub fn new(
+ schema: SchemaRef,
+ window_expr: Vec<Arc<dyn WindowExpr>>,
+ input: SendableRecordBatchStream,
+ ) -> Self {
+ let (tx, rx) = futures::channel::oneshot::channel();
+ let schema_clone = schema.clone();
+ // the JoinHandle is dropped, so the task runs detached
+ tokio::spawn(async move {
+ let result = compute_window_aggregate(schema_clone, window_expr, input).await;
+ // a failed send (receiver already dropped) is not handled
+ tx.send(result)
+ });
+
+ Self {
+ output: rx,
+ finished: false,
+ schema,
+ }
+ }
+}
+
+impl Stream for WindowAggStream {
+ type Item = ArrowResult<RecordBatch>;
+
+ /// Resolves to the single result sent by the spawned computation, then
+ /// returns `None` on every later poll.
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ if self.finished {
+ return Poll::Ready(None);
+ }
- Err(DataFusionError::NotImplemented(
- "WindowAggExec::execute".to_owned(),
- ))
+ // is the output ready?
+ let this = self.project();
+ let output_poll = this.output.poll(cx);
+
+ match output_poll {
+ Poll::Ready(result) => {
+ // the stream yields exactly one item, success or error
+ *this.finished = true;
+ // check for error in receiving channel and unwrap actual result
+ // (a receive error means the sender was dropped without sending)
+ let result = match result {
+ Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving
+ Ok(result) => Some(result),
+ };
+ Poll::Ready(result)
+ }
+ Poll::Pending => Poll::Pending,
+ }
+ }
+}
+
+impl RecordBatchStream for WindowAggStream {
+ /// Get the schema
+ ///
+ /// Returns the output schema this stream was created with.
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
 }
 }
+
+#[cfg(test)]
+mod tests {
+ // TODO: every test in this module is commented out, so it currently tests
+ // nothing. `window_function` is only a placeholder: its input plan is
+ // `unimplemented!()` and it makes no assertions. Window execution is
+ // exercised by `csv_query_window_with_empty_over` in datafusion/tests/sql.rs.
+ // use super::*;
+
+ // /// some mock data to test windows
+ // fn some_data() -> (Arc<Schema>, Vec<RecordBatch>) {
+ // // define a schema.
+ // let schema = Arc::new(Schema::new(vec![
+ // Field::new("a", DataType::UInt32, false),
+ // Field::new("b", DataType::Float64, false),
+ // ]));
+
+ // // define data.
+ // (
+ // schema.clone(),
+ // vec![
+ // RecordBatch::try_new(
+ // schema.clone(),
+ // vec![
+ // Arc::new(UInt32Array::from(vec![2, 3, 4, 4])),
+ // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+ // ],
+ // )
+ // .unwrap(),
+ // RecordBatch::try_new(
+ // schema,
+ // vec![
+ // Arc::new(UInt32Array::from(vec![2, 3, 3, 4])),
+ // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+ // ],
+ // )
+ // .unwrap(),
+ // ],
+ // )
+ // }
+
+ // #[tokio::test]
+ // async fn window_function() -> Result<()> {
+ // let input: Arc<dyn ExecutionPlan> = unimplemented!();
+ // let input_schema = input.schema();
+ // let window_expr = vec![];
+ // WindowAggExec::try_new(window_expr, input, input_schema);
+ // }
+}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index e68c53b..55bc88e 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -797,20 +797,31 @@ async fn csv_query_count() -> Result<()> {
Ok(())
}
-// FIXME uncomment this when exec is done
-// #[tokio::test]
-// async fn csv_query_window_with_empty_over() -> Result<()> {
-// let mut ctx = ExecutionContext::new();
-// register_aggregate_csv(&mut ctx)?;
-// let sql = "SELECT count(c12) over () FROM aggregate_test_100";
-// // FIXME: so far the WindowAggExec is not implemented
-// // and the current behavior is to throw not implemented exception
-
-// let result = execute(&mut ctx, sql).await;
-// let expected: Vec<Vec<String>> = vec![];
-// assert_eq!(result, expected);
-// Ok(())
-// }
+/// Aggregates over an empty `OVER ()` window: every output row carries the
+/// same whole-table sum/avg/count/max/min of `c3`, next to its own `c2`.
+#[tokio::test]
+async fn csv_query_window_with_empty_over() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx)?;
+ let sql = "select \
+ c2, \
+ sum(c3) over (), \
+ avg(c3) over (), \
+ count(c3) over (), \
+ max(c3) over (), \
+ min(c3) over () \
+ from aggregate_test_100 \
+ order by c2 \
+ limit 5";
+ let actual = execute(&mut ctx, sql).await;
+ // identical aggregate values on every row: the window spans all 100 rows
+ let expected = vec![
+ vec!["1", "781", "7.81", "100", "125", "-117"],
+ vec!["1", "781", "7.81", "100", "125", "-117"],
+ vec!["1", "781", "7.81", "100", "125", "-117"],
+ vec!["1", "781", "7.81", "100", "125", "-117"],
+ vec!["1", "781", "7.81", "100", "125", "-117"],
+ ];
+ assert_eq!(expected, actual);
+ Ok(())
+}
#[tokio::test]
async fn csv_query_group_by_int_count() -> Result<()> {
diff --git a/parquet-testing b/parquet-testing
index 8e7badc..ddd8989 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit 8e7badc6a3817a02e06d17b5d8ab6b6dc356e890
+Subproject commit ddd898958803cb89b7156c6350584d1cda0fe8de