You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/11/29 19:49:16 UTC
[arrow-datafusion] branch master updated: Change the arg names and make parameters more meaningful (#1357)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8aaf284 Change the arg names and make parameters more meaningful (#1357)
8aaf284 is described below
commit 8aaf284022509cb85df648892bae81c4daf26b7b
Author: Kun Liu <li...@apache.org>
AuthorDate: Tue Nov 30 03:48:13 2021 +0800
Change the arg names and make parameters more meaningful (#1357)
* Change the names of function arguments, and fix some parameter bugs
* Fix bugs in the functions
---
datafusion/src/physical_plan/aggregates.rs | 100 ++++++++++++--------
datafusion/src/physical_plan/functions.rs | 114 ++++++++++++++---------
datafusion/src/physical_plan/type_coercion.rs | 10 +-
datafusion/src/physical_plan/udaf.rs | 10 +-
datafusion/src/physical_plan/udf.rs | 10 +-
datafusion/src/physical_plan/window_functions.rs | 17 ++--
6 files changed, 157 insertions(+), 104 deletions(-)
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 0c99c4f..1ec33a4 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -93,90 +93,116 @@ impl FromStr for AggregateFunction {
}
}
-/// Returns the datatype of the scalar function
-pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
+/// Returns the datatype of the aggregation function
+pub fn return_type(
+ fun: &AggregateFunction,
+ input_expr_types: &[DataType],
+) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
// verify that this is a valid set of data types for this function
- data_types(arg_types, &signature(fun))?;
+ data_types(input_expr_types, &signature(fun))?;
match fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(DataType::UInt64)
}
- AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
- AggregateFunction::Sum => sum_return_type(&arg_types[0]),
- AggregateFunction::Avg => avg_return_type(&arg_types[0]),
+ AggregateFunction::Max | AggregateFunction::Min => {
+ Ok(input_expr_types[0].clone())
+ }
+ AggregateFunction::Sum => sum_return_type(&input_expr_types[0]),
+ AggregateFunction::Avg => avg_return_type(&input_expr_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
"item",
- arg_types[0].clone(),
+ input_expr_types[0].clone(),
true,
)))),
}
}
-/// Create a physical (function) expression.
-/// This function errors when `args`' can't be coerced to a valid argument type of the function.
+/// Create a physical aggregation expression.
+/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function.
pub fn create_aggregate_expr(
fun: &AggregateFunction,
distinct: bool,
- args: &[Arc<dyn PhysicalExpr>],
+ input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
let name = name.into();
- let arg = coerce(args, input_schema, &signature(fun))?;
- if arg.is_empty() {
+ let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?;
+ if coerced_phy_exprs.is_empty() {
return Err(DataFusionError::Plan(format!(
"Invalid or wrong number of arguments passed to aggregate: '{}'",
name,
)));
}
- let arg = arg[0].clone();
- let arg_types = args
+ let coerced_exprs_types = coerced_phy_exprs
+ .iter()
+ .map(|e| e.data_type(input_schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ let input_exprs_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
- let return_type = return_type(fun, &arg_types)?;
+ // In order to get the result data type, we must use the original input data type to calculate the result type.
+ let return_type = return_type(fun, &input_exprs_types)?;
Ok(match (fun, distinct) {
- (AggregateFunction::Count, false) => {
- Arc::new(expressions::Count::new(arg, name, return_type))
- }
+ (AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
(AggregateFunction::Count, true) => {
Arc::new(distinct_expressions::DistinctCount::new(
- arg_types,
- args.to_vec(),
+ coerced_exprs_types,
+ coerced_phy_exprs.to_vec(),
name,
return_type,
))
}
- (AggregateFunction::Sum, false) => {
- Arc::new(expressions::Sum::new(arg, name, return_type))
- }
+ (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
(AggregateFunction::Sum, true) => {
return Err(DataFusionError::NotImplemented(
"SUM(DISTINCT) aggregations are not available".to_string(),
));
}
- (AggregateFunction::ApproxDistinct, _) => Arc::new(
- expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
- ),
- (AggregateFunction::ArrayAgg, _) => {
- Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
- }
- (AggregateFunction::Min, _) => {
- Arc::new(expressions::Min::new(arg, name, return_type))
- }
- (AggregateFunction::Max, _) => {
- Arc::new(expressions::Max::new(arg, name, return_type))
- }
- (AggregateFunction::Avg, false) => {
- Arc::new(expressions::Avg::new(arg, name, return_type))
+ (AggregateFunction::ApproxDistinct, _) => {
+ Arc::new(expressions::ApproxDistinct::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ coerced_exprs_types[0].clone(),
+ ))
}
+ (AggregateFunction::ArrayAgg, _) => Arc::new(expressions::ArrayAgg::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ coerced_exprs_types[0].clone(),
+ )),
+ (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
+ (AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
+ (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new(
+ coerced_phy_exprs[0].clone(),
+ name,
+ return_type,
+ )),
(AggregateFunction::Avg, true) => {
return Err(DataFusionError::NotImplemented(
"AVG(DISTINCT) aggregations are not available".to_string(),
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index 72b2635..9c59b96 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -501,12 +501,12 @@ make_utf8_to_return_type!(utf8_to_binary_type, DataType::Binary, DataType::Binar
/// Returns the datatype of the scalar function
pub fn return_type(
fun: &BuiltinScalarFunction,
- arg_types: &[DataType],
+ input_expr_types: &[DataType],
) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
- if arg_types.is_empty() && !fun.supports_zero_argument() {
+ if input_expr_types.is_empty() && !fun.supports_zero_argument() {
return Err(DataFusionError::Internal(format!(
"Builtin scalar function {} does not support empty arguments",
fun
@@ -514,20 +514,22 @@ pub fn return_type(
}
// verify that this is a valid set of data types for this function
- data_types(arg_types, &signature(fun))?;
+ data_types(input_expr_types, &signature(fun))?;
// the return type of the built in function.
// Some built-in functions' return type depends on the incoming type.
match fun {
BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
- Box::new(Field::new("item", arg_types[0].clone(), true)),
- arg_types.len() as i32,
+ Box::new(Field::new("item", input_expr_types[0].clone(), true)),
+ input_expr_types.len() as i32,
)),
BuiltinScalarFunction::Ascii => Ok(DataType::Int32),
- BuiltinScalarFunction::BitLength => utf8_to_int_type(&arg_types[0], "bit_length"),
- BuiltinScalarFunction::Btrim => utf8_to_str_type(&arg_types[0], "btrim"),
+ BuiltinScalarFunction::BitLength => {
+ utf8_to_int_type(&input_expr_types[0], "bit_length")
+ }
+ BuiltinScalarFunction::Btrim => utf8_to_str_type(&input_expr_types[0], "btrim"),
BuiltinScalarFunction::CharacterLength => {
- utf8_to_int_type(&arg_types[0], "character_length")
+ utf8_to_int_type(&input_expr_types[0], "character_length")
}
BuiltinScalarFunction::Chr => Ok(DataType::Utf8),
BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
@@ -536,40 +538,58 @@ pub fn return_type(
BuiltinScalarFunction::DateTrunc => {
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
- BuiltinScalarFunction::InitCap => utf8_to_str_type(&arg_types[0], "initcap"),
- BuiltinScalarFunction::Left => utf8_to_str_type(&arg_types[0], "left"),
- BuiltinScalarFunction::Lower => utf8_to_str_type(&arg_types[0], "lower"),
- BuiltinScalarFunction::Lpad => utf8_to_str_type(&arg_types[0], "lpad"),
- BuiltinScalarFunction::Ltrim => utf8_to_str_type(&arg_types[0], "ltrim"),
- BuiltinScalarFunction::MD5 => utf8_to_str_type(&arg_types[0], "md5"),
+ BuiltinScalarFunction::InitCap => {
+ utf8_to_str_type(&input_expr_types[0], "initcap")
+ }
+ BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
+ BuiltinScalarFunction::Lower => utf8_to_str_type(&input_expr_types[0], "lower"),
+ BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
+ BuiltinScalarFunction::Ltrim => utf8_to_str_type(&input_expr_types[0], "ltrim"),
+ BuiltinScalarFunction::MD5 => utf8_to_str_type(&input_expr_types[0], "md5"),
BuiltinScalarFunction::NullIf => {
// NULLIF has two args and they might get coerced, get a preview of this
- let coerced_types = data_types(arg_types, &signature(fun));
+ let coerced_types = data_types(input_expr_types, &signature(fun));
coerced_types.map(|typs| typs[0].clone())
}
BuiltinScalarFunction::OctetLength => {
- utf8_to_int_type(&arg_types[0], "octet_length")
+ utf8_to_int_type(&input_expr_types[0], "octet_length")
}
BuiltinScalarFunction::Random => Ok(DataType::Float64),
BuiltinScalarFunction::RegexpReplace => {
- utf8_to_str_type(&arg_types[0], "regex_replace")
+ utf8_to_str_type(&input_expr_types[0], "regex_replace")
+ }
+ BuiltinScalarFunction::Repeat => utf8_to_str_type(&input_expr_types[0], "repeat"),
+ BuiltinScalarFunction::Replace => {
+ utf8_to_str_type(&input_expr_types[0], "replace")
+ }
+ BuiltinScalarFunction::Reverse => {
+ utf8_to_str_type(&input_expr_types[0], "reverse")
+ }
+ BuiltinScalarFunction::Right => utf8_to_str_type(&input_expr_types[0], "right"),
+ BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"),
+ BuiltinScalarFunction::Rtrim => utf8_to_str_type(&input_expr_types[0], "rtrimp"),
+ BuiltinScalarFunction::SHA224 => {
+ utf8_to_binary_type(&input_expr_types[0], "sha224")
+ }
+ BuiltinScalarFunction::SHA256 => {
+ utf8_to_binary_type(&input_expr_types[0], "sha256")
+ }
+ BuiltinScalarFunction::SHA384 => {
+ utf8_to_binary_type(&input_expr_types[0], "sha384")
+ }
+ BuiltinScalarFunction::SHA512 => {
+ utf8_to_binary_type(&input_expr_types[0], "sha512")
+ }
+ BuiltinScalarFunction::Digest => {
+ utf8_to_binary_type(&input_expr_types[0], "digest")
+ }
+ BuiltinScalarFunction::SplitPart => {
+ utf8_to_str_type(&input_expr_types[0], "split_part")
}
- BuiltinScalarFunction::Repeat => utf8_to_str_type(&arg_types[0], "repeat"),
- BuiltinScalarFunction::Replace => utf8_to_str_type(&arg_types[0], "replace"),
- BuiltinScalarFunction::Reverse => utf8_to_str_type(&arg_types[0], "reverse"),
- BuiltinScalarFunction::Right => utf8_to_str_type(&arg_types[0], "right"),
- BuiltinScalarFunction::Rpad => utf8_to_str_type(&arg_types[0], "rpad"),
- BuiltinScalarFunction::Rtrim => utf8_to_str_type(&arg_types[0], "rtrimp"),
- BuiltinScalarFunction::SHA224 => utf8_to_binary_type(&arg_types[0], "sha224"),
- BuiltinScalarFunction::SHA256 => utf8_to_binary_type(&arg_types[0], "sha256"),
- BuiltinScalarFunction::SHA384 => utf8_to_binary_type(&arg_types[0], "sha384"),
- BuiltinScalarFunction::SHA512 => utf8_to_binary_type(&arg_types[0], "sha512"),
- BuiltinScalarFunction::Digest => utf8_to_binary_type(&arg_types[0], "digest"),
- BuiltinScalarFunction::SplitPart => utf8_to_str_type(&arg_types[0], "split_part"),
BuiltinScalarFunction::StartsWith => Ok(DataType::Boolean),
- BuiltinScalarFunction::Strpos => utf8_to_int_type(&arg_types[0], "strpos"),
- BuiltinScalarFunction::Substr => utf8_to_str_type(&arg_types[0], "substr"),
- BuiltinScalarFunction::ToHex => Ok(match arg_types[0] {
+ BuiltinScalarFunction::Strpos => utf8_to_int_type(&input_expr_types[0], "strpos"),
+ BuiltinScalarFunction::Substr => utf8_to_str_type(&input_expr_types[0], "substr"),
+ BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
DataType::Utf8
}
@@ -593,10 +613,12 @@ pub fn return_type(
Ok(DataType::Timestamp(TimeUnit::Second, None))
}
BuiltinScalarFunction::Now => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
- BuiltinScalarFunction::Translate => utf8_to_str_type(&arg_types[0], "translate"),
- BuiltinScalarFunction::Trim => utf8_to_str_type(&arg_types[0], "trim"),
- BuiltinScalarFunction::Upper => utf8_to_str_type(&arg_types[0], "upper"),
- BuiltinScalarFunction::RegexpMatch => Ok(match arg_types[0] {
+ BuiltinScalarFunction::Translate => {
+ utf8_to_str_type(&input_expr_types[0], "translate")
+ }
+ BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"),
+ BuiltinScalarFunction::Upper => utf8_to_str_type(&input_expr_types[0], "upper"),
+ BuiltinScalarFunction::RegexpMatch => Ok(match input_expr_types[0] {
DataType::LargeUtf8 => {
DataType::List(Box::new(Field::new("item", DataType::LargeUtf8, true)))
}
@@ -628,7 +650,7 @@ pub fn return_type(
| BuiltinScalarFunction::Sin
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
- | BuiltinScalarFunction::Trunc => match arg_types[0] {
+ | BuiltinScalarFunction::Trunc => match input_expr_types[0] {
DataType::Float32 => Ok(DataType::Float32),
_ => Ok(DataType::Float64),
},
@@ -1130,18 +1152,18 @@ pub fn create_physical_fun(
/// This function errors when `args`' can't be coerced to a valid argument type of the function.
pub fn create_physical_expr(
fun: &BuiltinScalarFunction,
- args: &[Arc<dyn PhysicalExpr>],
+ input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn PhysicalExpr>> {
- let args = coerce(args, input_schema, &signature(fun))?;
+ let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?;
- let arg_types = args
+ let coerced_expr_types = coerced_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
- let data_type = return_type(fun, &arg_types)?;
+ let data_type = return_type(fun, &coerced_expr_types)?;
let fun_expr: ScalarFunctionImplementation = match fun {
// These functions need args and input schema to pick an implementation
@@ -1149,7 +1171,7 @@ pub fn create_physical_expr(
// here we return either a cast fn or string timestamp translation based on the expression data type
// so we don't have to pay a per-array/batch cost.
BuiltinScalarFunction::ToTimestamp => {
- Arc::new(match args[0].data_type(input_schema) {
+ Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -1169,7 +1191,7 @@ pub fn create_physical_expr(
})
}
BuiltinScalarFunction::ToTimestampMillis => {
- Arc::new(match args[0].data_type(input_schema) {
+ Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -1189,7 +1211,7 @@ pub fn create_physical_expr(
})
}
BuiltinScalarFunction::ToTimestampMicros => {
- Arc::new(match args[0].data_type(input_schema) {
+ Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -1209,7 +1231,7 @@ pub fn create_physical_expr(
})
}
BuiltinScalarFunction::ToTimestampSeconds => Arc::new({
- match args[0].data_type(input_schema) {
+ match coerced_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -1235,7 +1257,7 @@ pub fn create_physical_expr(
Ok(Arc::new(ScalarFunctionExpr::new(
&format!("{}", fun),
fun_expr,
- args,
+ coerced_phy_exprs,
&data_type,
)))
}
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index 801a83b..b413356 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -103,11 +103,11 @@ fn get_valid_types(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
let valid_types = match signature {
- TypeSignature::Variadic(valid_types, ..) => valid_types
+ TypeSignature::Variadic(valid_types) => valid_types
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
- TypeSignature::Uniform(number, valid_types, ..) => valid_types
+ TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
@@ -118,8 +118,8 @@ fn get_valid_types(
.map(|_| current_types[0].clone())
.collect()]
}
- TypeSignature::Exact(valid_types, ..) => vec![valid_types.clone()],
- TypeSignature::Any(number, ..) => {
+ TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
+ TypeSignature::Any(number) => {
if current_types.len() != *number {
return Err(DataFusionError::Plan(format!(
"The function expected {} arguments but received {}",
@@ -129,7 +129,7 @@ fn get_valid_types(
}
vec![(0..*number).map(|i| current_types[i].clone()).collect()]
}
- TypeSignature::OneOf(types, ..) => types
+ TypeSignature::OneOf(types) => types
.iter()
.filter_map(|t| get_valid_types(t, current_types).ok())
.flatten()
diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs
index d9d1404..08ea5d3 100644
--- a/datafusion/src/physical_plan/udaf.rs
+++ b/datafusion/src/physical_plan/udaf.rs
@@ -114,22 +114,22 @@ impl AggregateUDF {
/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF.
pub fn create_aggregate_expr(
fun: &AggregateUDF,
- args: &[Arc<dyn PhysicalExpr>],
+ input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
// coerce
- let args = coerce(args, input_schema, &fun.signature)?;
+ let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;
- let arg_types = args
+ let coerced_exprs_types = coerced_phy_exprs
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
- args: args.clone(),
- data_type: (fun.return_type)(&arg_types)?.as_ref().clone(),
+ args: coerced_phy_exprs.clone(),
+ data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
name: name.into(),
}))
}
diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs
index 39265c0..0c5e80b 100644
--- a/datafusion/src/physical_plan/udf.rs
+++ b/datafusion/src/physical_plan/udf.rs
@@ -110,13 +110,13 @@ impl ScalarUDF {
/// This function errors when `args`' can't be coerced to a valid argument type of the UDF.
pub fn create_physical_expr(
fun: &ScalarUDF,
- args: &[Arc<dyn PhysicalExpr>],
+ input_phy_exprs: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
// coerce
- let args = coerce(args, input_schema, &fun.signature)?;
+ let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;
- let arg_types = args
+ let coerced_exprs_types = coerced_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
@@ -124,7 +124,7 @@ pub fn create_physical_expr(
Ok(Arc::new(ScalarFunctionExpr::new(
&fun.name,
fun.fun.clone(),
- args,
- (fun.return_type)(&arg_types)?.as_ref(),
+ coerced_phy_exprs,
+ (fun.return_type)(&coerced_exprs_types)?.as_ref(),
)))
}
diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs
index 9070ca8..0cee845 100644
--- a/datafusion/src/physical_plan/window_functions.rs
+++ b/datafusion/src/physical_plan/window_functions.rs
@@ -149,11 +149,16 @@ impl FromStr for BuiltInWindowFunction {
}
/// Returns the datatype of the window function
-pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataType> {
+pub fn return_type(
+ fun: &WindowFunction,
+ input_expr_types: &[DataType],
+) -> Result<DataType> {
match fun {
- WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types),
+ WindowFunction::AggregateFunction(fun) => {
+ aggregates::return_type(fun, input_expr_types)
+ }
WindowFunction::BuiltInWindowFunction(fun) => {
- return_type_for_built_in(fun, arg_types)
+ return_type_for_built_in(fun, input_expr_types)
}
}
}
@@ -161,13 +166,13 @@ pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result<DataT
/// Returns the datatype of the built-in window function
pub(super) fn return_type_for_built_in(
fun: &BuiltInWindowFunction,
- arg_types: &[DataType],
+ input_expr_types: &[DataType],
) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
// verify that this is a valid set of data types for this function
- data_types(arg_types, &signature_for_built_in(fun))?;
+ data_types(input_expr_types, &signature_for_built_in(fun))?;
match fun {
BuiltInWindowFunction::RowNumber
@@ -181,7 +186,7 @@ pub(super) fn return_type_for_built_in(
| BuiltInWindowFunction::Lead
| BuiltInWindowFunction::FirstValue
| BuiltInWindowFunction::LastValue
- | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()),
+ | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
}
}