You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by co...@apache.org on 2024/02/23 03:16:35 UTC
(arrow-datafusion) branch main updated: Support IGNORE NULLS for LAG window function (#9221)
This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a851ecf1cc Support IGNORE NULLS for LAG window function (#9221)
a851ecf1cc is described below
commit a851ecf1cc24a6b867d40087d8e890b9307137c1
Author: comphead <co...@users.noreply.github.com>
AuthorDate: Thu Feb 22 19:16:30 2024 -0800
Support IGNORE NULLS for LAG window function (#9221)
* WIP lag/lead ignore nulls
* Support IGNORE NULLS for LAG function
* fmt
* comments
* remove comments
* Add new tests, minor changes, trigger evaluate_all
* Make algorithm pruning friendly
---------
Co-authored-by: Mustafa Akur <mu...@synnada.ai>
---
datafusion/core/src/dataframe/mod.rs | 1 +
.../core/src/physical_optimizer/test_utils.rs | 1 +
datafusion/core/src/physical_planner.rs | 6 ++
datafusion/core/tests/dataframe/mod.rs | 1 +
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 3 +
datafusion/expr/src/expr.rs | 18 ++++
datafusion/expr/src/tree_node/expr.rs | 2 +
datafusion/expr/src/udwf.rs | 1 +
datafusion/expr/src/utils.rs | 10 ++
.../optimizer/src/analyzer/count_wildcard_rule.rs | 3 +
datafusion/optimizer/src/analyzer/type_coercion.rs | 2 +
datafusion/optimizer/src/push_down_projection.rs | 2 +
datafusion/physical-expr/src/window/lead_lag.rs | 88 +++++++++++++++--
datafusion/physical-plan/src/windows/mod.rs | 8 +-
datafusion/proto/src/logical_plan/from_proto.rs | 6 ++
datafusion/proto/src/logical_plan/to_proto.rs | 2 +
datafusion/proto/src/physical_plan/from_proto.rs | 1 +
.../proto/tests/cases/roundtrip_logical_plan.rs | 6 ++
datafusion/sql/src/expr/function.rs | 16 +--
datafusion/sqllogictest/test_files/window.slt | 107 +++++++++++++++++++++
datafusion/substrait/src/logical_plan/consumer.rs | 1 +
datafusion/substrait/src/logical_plan/producer.rs | 1 +
22 files changed, 272 insertions(+), 14 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 4ec16ac942..e407c477ae 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1685,6 +1685,7 @@ mod tests {
vec![col("aggregate_test_100.c2")],
vec![],
WindowFrame::new(None),
+ None,
));
let t2 = t.select(vec![col("c1"), first_row])?;
let plan = t2.plan.clone();
diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs
index ca7fb78d21..3898fb6345 100644
--- a/datafusion/core/src/physical_optimizer/test_utils.rs
+++ b/datafusion/core/src/physical_optimizer/test_utils.rs
@@ -245,6 +245,7 @@ pub fn bounded_window_exec(
&sort_exprs,
Arc::new(WindowFrame::new(Some(false))),
schema.as_ref(),
+ false,
)
.unwrap()],
input.clone(),
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index dabf0a91b2..23ac7e08ca 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -100,6 +100,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::{multiunzip, Itertools};
use log::{debug, trace};
+use sqlparser::ast::NullTreatment;
fn create_function_physical_name(
fun: &str,
@@ -1581,6 +1582,7 @@ pub fn create_window_expr_with_name(
partition_by,
order_by,
window_frame,
+ null_treatment,
}) => {
let args = args
.iter()
@@ -1605,6 +1607,9 @@ pub fn create_window_expr_with_name(
}
let window_frame = Arc::new(window_frame.clone());
+ let ignore_nulls = null_treatment
+ .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
+ == NullTreatment::IgnoreNulls;
windows::create_window_expr(
fun,
name,
@@ -1613,6 +1618,7 @@ pub fn create_window_expr_with_name(
&order_by,
window_frame,
physical_input_schema,
+ ignore_nulls,
)
}
other => plan_err!("Invalid window expression '{other:?}'"),
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index f650e9e39d..b08b2b8fc7 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -182,6 +182,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
+ None,
))])?
.explain(false, false)?
.collect()
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index d22d0c0f2e..609d26c9c2 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -281,6 +281,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
&orderby_exprs,
Arc::new(window_frame),
schema.as_ref(),
+ false,
)?;
let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
vec![window_expr],
@@ -642,6 +643,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
+ false,
)
.unwrap()],
exec1,
@@ -664,6 +666,7 @@ async fn run_window_test(
&orderby_exprs,
Arc::new(window_frame.clone()),
schema.as_ref(),
+ false,
)
.unwrap()],
exec2,
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 09de4b708d..f40ccb6cdb 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -30,6 +30,7 @@ use arrow::datatypes::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, DFSchema, OwnedTableReference};
use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue};
+use sqlparser::ast::NullTreatment;
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
@@ -646,6 +647,7 @@ pub struct WindowFunction {
pub order_by: Vec<Expr>,
/// Window frame
pub window_frame: window_frame::WindowFrame,
+ pub null_treatment: Option<NullTreatment>,
}
impl WindowFunction {
@@ -656,6 +658,7 @@ impl WindowFunction {
partition_by: Vec<Expr>,
order_by: Vec<Expr>,
window_frame: window_frame::WindowFrame,
+ null_treatment: Option<NullTreatment>,
) -> Self {
Self {
fun,
@@ -663,6 +666,7 @@ impl WindowFunction {
partition_by,
order_by,
window_frame,
+ null_treatment,
}
}
}
@@ -1440,8 +1444,14 @@ impl fmt::Display for Expr {
partition_by,
order_by,
window_frame,
+ null_treatment,
}) => {
fmt_function(f, &fun.to_string(), false, args, true)?;
+
+ if let Some(nt) = null_treatment {
+ write!(f, "{}", nt)?;
+ }
+
if !partition_by.is_empty() {
write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?;
}
@@ -1768,15 +1778,23 @@ fn create_name(e: &Expr) -> Result<String> {
window_frame,
partition_by,
order_by,
+ null_treatment,
}) => {
let mut parts: Vec<String> =
vec![create_function_name(&fun.to_string(), false, args)?];
+
+ if let Some(nt) = null_treatment {
+ parts.push(format!("{}", nt));
+ }
+
if !partition_by.is_empty() {
parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by)));
}
+
if !order_by.is_empty() {
parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by)));
}
+
parts.push(format!("{window_frame}"));
Ok(parts.join(" "))
}
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs
index add15b3d7a..def25ed924 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -283,12 +283,14 @@ impl TreeNode for Expr {
partition_by,
order_by,
window_frame,
+ null_treatment,
}) => Expr::WindowFunction(WindowFunction::new(
fun,
transform_vec(args, &mut transform)?,
transform_vec(partition_by, &mut transform)?,
transform_vec(order_by, &mut transform)?,
window_frame,
+ null_treatment,
)),
Expr::AggregateFunction(AggregateFunction {
args,
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 9534834088..7e3eb6c001 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -130,6 +130,7 @@ impl WindowUDF {
partition_by,
order_by,
window_frame,
+ null_treatment: None,
})
}
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index e855554f36..2fda81d889 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -1255,6 +1255,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
@@ -1262,6 +1263,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
@@ -1269,6 +1271,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1276,6 +1279,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
@@ -1298,6 +1302,7 @@ mod tests {
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
+ None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
@@ -1305,6 +1310,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let min3 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
@@ -1312,6 +1318,7 @@ mod tests {
vec![],
vec![age_asc.clone(), name_desc.clone()],
WindowFrame::new(Some(false)),
+ None,
));
let sum4 = Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1319,6 +1326,7 @@ mod tests {
vec![],
vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()],
WindowFrame::new(Some(false)),
+ None,
));
// FIXME use as_ref
let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
@@ -1353,6 +1361,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)),
],
WindowFrame::new(Some(false)),
+ None,
)),
Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum),
@@ -1364,6 +1373,7 @@ mod tests {
Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)),
],
WindowFrame::new(Some(false)),
+ None,
)),
];
let expected = vec![
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 35a8597832..9242e68562 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -128,6 +128,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
+ null_treatment,
}) if args.len() == 1 => match args[0] {
Expr::Wildcard { qualifier: None } => {
Expr::WindowFunction(expr::WindowFunction {
@@ -138,6 +139,7 @@ impl TreeNodeRewriter for CountWildcardRewriter {
partition_by,
order_by,
window_frame,
+ null_treatment,
})
}
@@ -351,6 +353,7 @@ mod tests {
WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
),
+ None,
))])?
.project(vec![count(wildcard())])?
.build()?;
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index fba77047dd..8cdb4d7dbd 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -392,6 +392,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
+ null_treatment,
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;
@@ -414,6 +415,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
partition_by,
order_by,
window_frame,
+ null_treatment,
));
Ok(expr)
}
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 6a003ecb5f..8b7a9148b5 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -587,6 +587,7 @@ mod tests {
vec![col("test.b")],
vec![],
WindowFrame::new(None),
+ None,
));
let max2 = Expr::WindowFunction(expr::WindowFunction::new(
@@ -595,6 +596,7 @@ mod tests {
vec![],
vec![],
WindowFrame::new(None),
+ None,
));
let col1 = col(max1.display_name()?);
let col2 = col(max2.display_name()?);
diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs
index 6a33f26ca1..6e1aad575f 100644
--- a/datafusion/physical-expr/src/window/lead_lag.rs
+++ b/datafusion/physical-expr/src/window/lead_lag.rs
@@ -23,10 +23,14 @@ use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field};
-use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue};
+use arrow_array::Array;
+use datafusion_common::{
+ arrow_datafusion_err, exec_datafusion_err, DataFusionError, Result, ScalarValue,
+};
use datafusion_expr::PartitionEvaluator;
use std::any::Any;
use std::cmp::min;
+use std::collections::VecDeque;
use std::ops::{Neg, Range};
use std::sync::Arc;
@@ -39,6 +43,7 @@ pub struct WindowShift {
shift_offset: i64,
expr: Arc<dyn PhysicalExpr>,
default_value: Option<ScalarValue>,
+ ignore_nulls: bool,
}
impl WindowShift {
@@ -60,6 +65,7 @@ pub fn lead(
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
+ ignore_nulls: bool,
) -> WindowShift {
WindowShift {
name,
@@ -67,6 +73,7 @@ pub fn lead(
shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1),
expr,
default_value,
+ ignore_nulls,
}
}
@@ -77,6 +84,7 @@ pub fn lag(
expr: Arc<dyn PhysicalExpr>,
shift_offset: Option<i64>,
default_value: Option<ScalarValue>,
+ ignore_nulls: bool,
) -> WindowShift {
WindowShift {
name,
@@ -84,6 +92,7 @@ pub fn lag(
shift_offset: shift_offset.unwrap_or(1),
expr,
default_value,
+ ignore_nulls,
}
}
@@ -110,6 +119,8 @@ impl BuiltInWindowFunctionExpr for WindowShift {
Ok(Box::new(WindowShiftEvaluator {
shift_offset: self.shift_offset,
default_value: self.default_value.clone(),
+ ignore_nulls: self.ignore_nulls,
+ non_null_offsets: VecDeque::new(),
}))
}
@@ -120,6 +131,7 @@ impl BuiltInWindowFunctionExpr for WindowShift {
shift_offset: -self.shift_offset,
expr: self.expr.clone(),
default_value: self.default_value.clone(),
+ ignore_nulls: self.ignore_nulls,
}))
}
}
@@ -128,6 +140,16 @@ impl BuiltInWindowFunctionExpr for WindowShift {
pub(crate) struct WindowShiftEvaluator {
shift_offset: i64,
default_value: Option<ScalarValue>,
+ ignore_nulls: bool,
+ // VecDeque contains offset values that between non-null entries
+ non_null_offsets: VecDeque<usize>,
+}
+
+impl WindowShiftEvaluator {
+ fn is_lag(&self) -> bool {
+ // Mode is LAG, when shift_offset is positive
+ self.shift_offset > 0
+ }
}
fn create_empty_array(
@@ -182,9 +204,13 @@ fn shift_with_default_value(
impl PartitionEvaluator for WindowShiftEvaluator {
fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
- if self.shift_offset > 0 {
- let offset = self.shift_offset as usize;
- let start = idx.saturating_sub(offset);
+ if self.is_lag() {
+ let start = if self.non_null_offsets.len() == self.shift_offset as usize {
+ let offset: usize = self.non_null_offsets.iter().sum();
+ idx.saturating_sub(offset + 1)
+ } else {
+ 0
+ };
let end = idx + 1;
Ok(Range { start, end })
} else {
@@ -196,7 +222,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
fn is_causal(&self) -> bool {
// Lagging windows are causal by definition:
- self.shift_offset > 0
+ self.is_lag()
}
fn evaluate(
@@ -204,17 +230,57 @@ impl PartitionEvaluator for WindowShiftEvaluator {
values: &[ArrayRef],
range: &Range<usize>,
) -> Result<ScalarValue> {
+ // TODO: try to get rid of i64 usize conversion
+ // TODO: do not recalculate default value every call
+ // TODO: support LEAD mode for IGNORE NULLS
let array = &values[0];
let dtype = array.data_type();
+ let len = array.len() as i64;
// LAG mode
- let idx = if self.shift_offset > 0 {
+ let mut idx = if self.is_lag() {
range.end as i64 - self.shift_offset - 1
} else {
// LEAD mode
range.start as i64 - self.shift_offset
};
- if idx < 0 || idx as usize >= array.len() {
+ // Support LAG only for now, as LEAD requires some brainstorm first
+ // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows
+ // If current row index points to NULL value the row is NOT counted
+ if self.ignore_nulls && self.is_lag() {
+ // Find the nonNULL row index that shifted by offset comparing to current row index
+ idx = if self.non_null_offsets.len() == self.shift_offset as usize {
+ let total_offset: usize = self.non_null_offsets.iter().sum();
+ (range.end - 1 - total_offset) as i64
+ } else {
+ -1
+ };
+
+ // Keep track of offset values between non-null entries
+ if array.is_valid(range.end - 1) {
+ // Non-null add new offset
+ self.non_null_offsets.push_back(1);
+ if self.non_null_offsets.len() > self.shift_offset as usize {
+ // WE do not need to keep track of more than `lag number of offset` values.
+ self.non_null_offsets.pop_front();
+ }
+ } else if !self.non_null_offsets.is_empty() {
+ // Entry is null, increment offset value of the last entry.
+ let end_idx = self.non_null_offsets.len() - 1;
+ self.non_null_offsets[end_idx] += 1;
+ }
+ } else if self.ignore_nulls && !self.is_lag() {
+ // IGNORE NULLS and LEAD mode.
+ return Err(exec_datafusion_err!(
+ "IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec"
+ ));
+ }
+
+ // Set the default value if
+ // - index is out of window bounds
+ // OR
+ // - ignore nulls mode and current value is null and is within window bounds
+ if idx < 0 || idx >= len || (self.ignore_nulls && array.is_null(idx as usize)) {
get_default_value(self.default_value.as_ref(), dtype)
} else {
ScalarValue::try_from_array(array, idx as usize)
@@ -226,6 +292,11 @@ impl PartitionEvaluator for WindowShiftEvaluator {
values: &[ArrayRef],
_num_rows: usize,
) -> Result<ArrayRef> {
+ if self.ignore_nulls {
+ return Err(exec_datafusion_err!(
+ "IGNORE NULLS mode for LAG and LEAD is not supported for WindowAggExec"
+ ));
+ }
// LEAD, LAG window functions take single column, values will have size 1
let value = &values[0];
shift_with_default_value(value, self.shift_offset, self.default_value.as_ref())
@@ -279,6 +350,7 @@ mod tests {
Arc::new(Column::new("c3", 0)),
None,
None,
+ false,
),
[
Some(-2),
@@ -301,6 +373,7 @@ mod tests {
Arc::new(Column::new("c3", 0)),
None,
None,
+ false,
),
[
None,
@@ -323,6 +396,7 @@ mod tests {
Arc::new(Column::new("c3", 0)),
None,
Some(ScalarValue::Int32(Some(100))),
+ false,
),
[
Some(100),
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index 693d20e90a..bf6ed92535 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -55,6 +55,7 @@ pub use datafusion_physical_expr::window::{
};
/// Create a physical expression for window function
+#[allow(clippy::too_many_arguments)]
pub fn create_window_expr(
fun: &WindowFunctionDefinition,
name: String,
@@ -63,6 +64,7 @@ pub fn create_window_expr(
order_by: &[PhysicalSortExpr],
window_frame: Arc<WindowFrame>,
input_schema: &Schema,
+ ignore_nulls: bool,
) -> Result<Arc<dyn WindowExpr>> {
Ok(match fun {
WindowFunctionDefinition::AggregateFunction(fun) => {
@@ -83,7 +85,7 @@ pub fn create_window_expr(
}
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
Arc::new(BuiltInWindowExpr::new(
- create_built_in_window_expr(fun, args, input_schema, name)?,
+ create_built_in_window_expr(fun, args, input_schema, name, ignore_nulls)?,
partition_by,
order_by,
window_frame,
@@ -159,6 +161,7 @@ fn create_built_in_window_expr(
args: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
name: String,
+ ignore_nulls: bool,
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
// need to get the types into an owned vec for some reason
let input_types: Vec<_> = args
@@ -208,6 +211,7 @@ fn create_built_in_window_expr(
arg,
shift_offset,
default_value,
+ ignore_nulls,
))
}
BuiltInWindowFunction::Lead => {
@@ -222,6 +226,7 @@ fn create_built_in_window_expr(
arg,
shift_offset,
default_value,
+ ignore_nulls,
))
}
BuiltInWindowFunction::NthValue => {
@@ -671,6 +676,7 @@ mod tests {
&[],
Arc::new(WindowFrame::new(None)),
schema.as_ref(),
+ false,
)?],
blocking_exec,
vec![],
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index f1ee84a822..2554018a92 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1100,6 +1100,8 @@ pub fn parse_expr(
"missing window frame during deserialization".to_string(),
)
})?;
+ // TODO: support proto for null treatment
+ let null_treatment = None;
regularize_window_order_by(&window_frame, &mut order_by)?;
match window_function {
@@ -1114,6 +1116,7 @@ pub fn parse_expr(
partition_by,
order_by,
window_frame,
+ None
)))
}
window_expr_node::WindowFunction::BuiltInFunction(i) => {
@@ -1133,6 +1136,7 @@ pub fn parse_expr(
partition_by,
order_by,
window_frame,
+ null_treatment
)))
}
window_expr_node::WindowFunction::Udaf(udaf_name) => {
@@ -1148,6 +1152,7 @@ pub fn parse_expr(
partition_by,
order_by,
window_frame,
+ None,
)))
}
window_expr_node::WindowFunction::Udwf(udwf_name) => {
@@ -1163,6 +1168,7 @@ pub fn parse_expr(
partition_by,
order_by,
window_frame,
+ None,
)))
}
}
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index a6348e909c..ccadbb217a 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -606,6 +606,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
ref partition_by,
ref order_by,
ref window_frame,
+ // TODO: support null treatment in proto
+ null_treatment: _,
}) => {
let window_function = match fun {
WindowFunctionDefinition::AggregateFunction(fun) => {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 628ee5ad9b..af0aa485c3 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -176,6 +176,7 @@ pub fn parse_physical_window_expr(
&order_by,
Arc::new(window_frame),
input_schema,
+ false,
)
}
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 81f5997547..6ca7579081 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1718,6 +1718,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
WindowFrame::new(Some(false)),
+ None,
));
// 2. with default window_frame
@@ -1729,6 +1730,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
WindowFrame::new(Some(false)),
+ None,
));
// 3. with window_frame with row numbers
@@ -1746,6 +1748,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
range_number_frame,
+ None,
));
// 4. test with AggregateFunction
@@ -1761,6 +1764,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
+ None,
));
// 5. test with AggregateUDF
@@ -1812,6 +1816,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
row_number_frame.clone(),
+ None,
));
ctx.register_udaf(dummy_agg);
@@ -1887,6 +1892,7 @@ fn roundtrip_window() {
vec![col("col1")],
vec![col("col2")],
row_number_frame,
+ None,
));
ctx.register_udwf(dummy_window_udf);
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index 64b8d6957d..f56138066c 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -52,8 +52,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
order_by,
} = function;
- if let Some(null_treatment) = null_treatment {
- return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}");
+ // If function is a window function (it has an OVER clause),
+ // it shouldn't have ordering requirement as function argument
+ // required ordering should be defined in OVER clause.
+ let is_function_window = over.is_some();
+
+ match null_treatment {
+ Some(null_treatment) if !is_function_window => return not_impl_err!("Null treatment in aggregate functions is not supported: {null_treatment}"),
+ _ => {}
}
let name = if name.0.len() > 1 {
@@ -120,10 +126,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)));
};
- // If function is a window function (it has an OVER clause),
- // it shouldn't have ordering requirement as function argument
- // required ordering should be defined in OVER clause.
- let is_function_window = over.is_some();
if !order_by.is_empty() && is_function_window {
return plan_err!(
"Aggregate ORDER BY is not implemented for window functions"
@@ -198,6 +200,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
partition_by,
order_by,
window_frame,
+ null_treatment,
))
}
_ => Expr::WindowFunction(expr::WindowFunction::new(
@@ -206,6 +209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
partition_by,
order_by,
window_frame,
+ null_treatment,
)),
};
return Ok(expr);
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 9276f6e1e3..8d6b314747 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -4102,3 +4102,110 @@ ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRE
----------CoalesceBatchesExec: target_batch_size=4096
------------FilterExec: a@0 = 1
--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+# LAG window function IGNORE/RESPECT NULLS support with ascending order and default offset 1
+query TTTTTT
+select lag(a) ignore nulls over (order by id) as x,
+ lag(a, 1, null) ignore nulls over (order by id) as x1,
+ lag(a, 1, 'def') ignore nulls over (order by id) as x2,
+ lag(a) respect nulls over (order by id) as x3,
+ lag(a, 1, null) respect nulls over (order by id) as x4,
+ lag(a, 1, 'def') respect nulls over (order by id) as x5
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+----
+NULL NULL def NULL NULL def
+NULL NULL def NULL NULL NULL
+b b b b b b
+b b b NULL NULL NULL
+
+# LAG window function IGNORE/RESPECT NULLS support with descending order and default offset 1
+query TTTTTT
+select lag(a) ignore nulls over (order by id desc) as x,
+ lag(a, 1, null) ignore nulls over (order by id desc) as x1,
+ lag(a, 1, 'def') ignore nulls over (order by id desc) as x2,
+ lag(a) respect nulls over (order by id desc) as x3,
+ lag(a, 1, null) respect nulls over (order by id desc) as x4,
+ lag(a, 1, 'def') respect nulls over (order by id desc) as x5
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+----
+NULL NULL def NULL NULL def
+x x x x x x
+x x x NULL NULL NULL
+b b b b b b
+
+# LAG window function IGNORE/RESPECT NULLS support with ascending order and nondefault offset
+query TTTT
+select lag(a, 2, null) ignore nulls over (order by id) as x1,
+ lag(a, 2, 'def') ignore nulls over (order by id) as x2,
+ lag(a, 2, null) respect nulls over (order by id) as x4,
+ lag(a, 2, 'def') respect nulls over (order by id) as x5
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+----
+NULL def NULL def
+NULL def NULL def
+NULL def NULL NULL
+NULL def b b
+
+# LAG window function IGNORE/RESPECT NULLS support with descending order and nondefault offset
+query TTTT
+select lag(a, 2, null) ignore nulls over (order by id desc) as x1,
+ lag(a, 2, 'def') ignore nulls over (order by id desc) as x2,
+ lag(a, 2, null) respect nulls over (order by id desc) as x4,
+ lag(a, 2, 'def') respect nulls over (order by id desc) as x5
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+----
+NULL def NULL def
+NULL def NULL def
+NULL def x x
+x x NULL NULL
+
+# LAG window function IGNORE/RESPECT NULLS support with descending order and nondefault offset.
+# To trigger WindowAggExec, we added a sum window function with all of the ranges.
+statement error Execution error: IGNORE NULLS mode for LAG and LEAD is not supported for WindowAggExec
+select lag(a, 2, null) ignore nulls over (order by id desc) as x1,
+ lag(a, 2, 'def') ignore nulls over (order by id desc) as x2,
+ lag(a, 2, null) respect nulls over (order by id desc) as x4,
+ lag(a, 2, 'def') respect nulls over (order by id desc) as x5,
+ sum(id) over (order by id desc ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as sum_id
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+
+# LEAD window function IGNORE/RESPECT NULLS support with descending order and nondefault offset
+statement error Execution error: IGNORE NULLS mode for LEAD is not supported for BoundedWindowAggExec
+select lead(a, 2, null) ignore nulls over (order by id desc) as x1,
+ lead(a, 2, 'def') ignore nulls over (order by id desc) as x2,
+ lead(a, 2, null) respect nulls over (order by id desc) as x4,
+ lead(a, 2, 'def') respect nulls over (order by id desc) as x5
+from (select 2 id, 'b' a union all select 1 id, null a union all select 3 id, null union all select 4 id, 'x')
+
+statement ok
+set datafusion.execution.batch_size = 1000;
+
+query I
+SELECT LAG(c1, 2) IGNORE NULLS OVER()
+FROM null_cases
+ORDER BY c2
+LIMIT 5;
+----
+78
+63
+3
+24
+14
+
+# result should be same with above, when lag algorithm work with pruned data.
+# decreasing batch size, causes data to be produced in smaller chunks at the source.
+# Hence sliding window algorithm is used during calculations.
+statement ok
+set datafusion.execution.batch_size = 1;
+
+query I
+SELECT LAG(c1, 2) IGNORE NULLS OVER()
+FROM null_cases
+ORDER BY c2
+LIMIT 5;
+----
+78
+63
+3
+24
+14
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 58a741c634..23a7ee05d7 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -978,6 +978,7 @@ pub async fn from_substrait_rex(
from_substrait_bound(&window.lower_bound, true)?,
from_substrait_bound(&window.upper_bound, false)?,
),
+ null_treatment: None,
})))
}
Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index fc9517c90a..9b29c0c677 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -1115,6 +1115,7 @@ pub fn to_substrait_rex(
partition_by,
order_by,
window_frame,
+ null_treatment: _,
}) => {
// function reference
let function_anchor = _register_function(fun.to_string(), extension_info);