You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/10/06 10:33:32 UTC
[arrow-datafusion] branch master updated: Consolidate coercion code in `datafusion_expr::type_coercion` and submodules (#3728)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7c5c2e5e3 Consolidate coercion code in `datafusion_expr::type_coercion` and submodules (#3728)
7c5c2e5e3 is described below
commit 7c5c2e5e399cf4b20527864385fc5de643d12021
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Oct 6 06:33:25 2022 -0400
Consolidate coercion code in `datafusion_expr::type_coercion` and submodules (#3728)
* Move function coercion to its own module
* Move binary into type coercion
* fmt
* More updates
* consolidate some more
* Move aggregates
---
datafusion/expr/src/aggregate_function.rs | 698 +--------------------
datafusion/expr/src/expr_schema.rs | 2 +-
datafusion/expr/src/function.rs | 2 +-
datafusion/expr/src/lib.rs | 1 -
datafusion/expr/src/logical_plan/builder.rs | 2 +-
datafusion/expr/src/type_coercion.rs | 266 +-------
.../aggregates.rs} | 328 ++--------
.../{binary_rule.rs => type_coercion/binary.rs} | 25 +-
.../functions.rs} | 22 +-
datafusion/expr/src/type_coercion/other.rs | 57 ++
datafusion/expr/src/window_function.rs | 2 +-
datafusion/optimizer/src/type_coercion.rs | 44 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 2 +-
.../physical-expr/src/aggregate/coercion_rule.rs | 4 +-
datafusion/physical-expr/src/expressions/binary.rs | 4 +-
datafusion/physical-expr/src/expressions/case.rs | 2 +-
.../physical-expr/src/expressions/in_list.rs | 2 +-
.../physical-expr/src/expressions/negative.rs | 2 +-
datafusion/physical-expr/src/functions.rs | 5 +-
datafusion/physical-expr/src/type_coercion.rs | 2 +-
20 files changed, 188 insertions(+), 1284 deletions(-)
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index 7b8616921..ce03a252e 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -17,45 +17,11 @@
//! Aggregate function module contains all built-in aggregate functions definitions
-use crate::{Signature, TypeSignature, Volatility};
-use arrow::datatypes::{
- DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
-};
+use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};
+use arrow::datatypes::{DataType, Field};
use datafusion_common::{DataFusionError, Result};
-use std::ops::Deref;
use std::{fmt, str::FromStr};
-pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
-
-pub static NUMERICS: &[DataType] = &[
- DataType::Int8,
- DataType::Int16,
- DataType::Int32,
- DataType::Int64,
- DataType::UInt8,
- DataType::UInt16,
- DataType::UInt32,
- DataType::UInt64,
- DataType::Float32,
- DataType::Float64,
-];
-
-pub static TIMESTAMPS: &[DataType] = &[
- DataType::Timestamp(TimeUnit::Second, None),
- DataType::Timestamp(TimeUnit::Millisecond, None),
- DataType::Timestamp(TimeUnit::Microsecond, None),
- DataType::Timestamp(TimeUnit::Nanosecond, None),
-];
-
-pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
-
-pub static TIMES: &[DataType] = &[
- DataType::Time32(TimeUnit::Second),
- DataType::Time32(TimeUnit::Millisecond),
- DataType::Time64(TimeUnit::Microsecond),
- DataType::Time64(TimeUnit::Nanosecond),
-];
-
/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum AggregateFunction {
@@ -154,7 +120,11 @@ pub fn return_type(
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
- let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?;
+ let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
+ fun,
+ input_expr_types,
+ &signature(fun),
+ )?;
match fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
@@ -192,167 +162,6 @@ pub fn return_type(
}
}
-/// Returns the coerced data type for each `input_types`.
-/// Different aggregate function with different input data type will get corresponding coerced data type.
-pub fn coerce_types(
- agg_fun: &AggregateFunction,
- input_types: &[DataType],
- signature: &Signature,
-) -> Result<Vec<DataType>> {
- // Validate input_types matches (at least one of) the func signature.
- check_arg_count(agg_fun, input_types, &signature.type_signature)?;
-
- match agg_fun {
- AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
- Ok(input_types.to_vec())
- }
- AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
- AggregateFunction::Min | AggregateFunction::Max => {
- // min and max support the dictionary data type
- // unpack the dictionary to get the value
- get_min_max_result_type(input_types)
- }
- AggregateFunction::Sum => {
- // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
- // smallint, int, bigint, real, double precision, decimal, or interval.
- if !is_sum_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Avg => {
- // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
- // smallint, int, bigint, real, double precision, decimal, or interval
- if !is_avg_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Variance => {
- if !is_variance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::VariancePop => {
- if !is_variance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Covariance => {
- if !is_covariance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::CovariancePop => {
- if !is_covariance_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Stddev => {
- if !is_stddev_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::StddevPop => {
- if !is_stddev_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Correlation => {
- if !is_correlation_support_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::ApproxPercentileCont => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- if !matches!(input_types[1], DataType::Float64) {
- return Err(DataFusionError::Plan(format!(
- "The percentile argument for {:?} must be Float64, not {:?}.",
- agg_fun, input_types[1]
- )));
- }
- if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) {
- return Err(DataFusionError::Plan(format!(
- "The percentile sample points count for {:?} must be integer, not {:?}.",
- agg_fun, input_types[2]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::ApproxPercentileContWithWeight => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
- return Err(DataFusionError::Plan(format!(
- "The weight argument for {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[1]
- )));
- }
- if !matches!(input_types[2], DataType::Float64) {
- return Err(DataFusionError::Plan(format!(
- "The percentile argument for {:?} must be Float64, not {:?}.",
- agg_fun, input_types[2]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::ApproxMedian => {
- if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not support inputs of type {:?}.",
- agg_fun, input_types[0]
- )));
- }
- Ok(input_types.to_vec())
- }
- AggregateFunction::Median => Ok(input_types.to_vec()),
- AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
- }
-}
-
/// the signatures supported by the function `fun`.
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
@@ -414,496 +223,3 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
),
}
}
-
-/// function return type of a sum
-pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
- Ok(DataType::Int64)
- }
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
- Ok(DataType::UInt64)
- }
- // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
- // the result type of floating-point is FLOAT64 with the double precision.
- DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
- DataType::Decimal128(precision, scale) => {
- // in the spark, the result type is DECIMAL(min(38,precision+10), s)
- // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
- let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
- Ok(DataType::Decimal128(new_precision, *scale))
- }
- other => Err(DataFusionError::Plan(format!(
- "SUM does not support type \"{:?}\"",
- other
- ))),
- }
-}
-
-/// function return type of variance
-pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "VAR does not support {:?}",
- other
- ))),
- }
-}
-
-/// function return type of covariance
-pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "COVAR does not support {:?}",
- other
- ))),
- }
-}
-
-/// function return type of correlation
-pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "CORR does not support {:?}",
- other
- ))),
- }
-}
-
-/// function return type of standard deviation
-pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "STDDEV does not support {:?}",
- other
- ))),
- }
-}
-
-/// function return type of an average
-pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
- match arg_type {
- DataType::Decimal128(precision, scale) => {
- // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
- // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
- let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
- let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
- Ok(DataType::Decimal128(new_precision, new_scale))
- }
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float32
- | DataType::Float64 => Ok(DataType::Float64),
- other => Err(DataFusionError::Plan(format!(
- "AVG does not support {:?}",
- other
- ))),
- }
-}
-
-/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
-///
-/// This method DOES NOT validate the argument types - only that (at least one,
-/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
-/// number of input types.
-fn check_arg_count(
- agg_fun: &AggregateFunction,
- input_types: &[DataType],
- signature: &TypeSignature,
-) -> Result<()> {
- match signature {
- TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
- if input_types.len() != *agg_count {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} expects {:?} arguments, but {:?} were provided",
- agg_fun,
- agg_count,
- input_types.len()
- )));
- }
- }
- TypeSignature::Exact(types) => {
- if types.len() != input_types.len() {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} expects {:?} arguments, but {:?} were provided",
- agg_fun,
- types.len(),
- input_types.len()
- )));
- }
- }
- TypeSignature::OneOf(variants) => {
- let ok = variants
- .iter()
- .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
- if !ok {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not accept {:?} function arguments.",
- agg_fun,
- input_types.len()
- )));
- }
- }
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Aggregate functions do not support this {:?}",
- signature
- )));
- }
- }
- Ok(())
-}
-
-fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
- // make sure that the input types only has one element.
- assert_eq!(input_types.len(), 1);
- // min and max support the dictionary data type
- // unpack the dictionary to get the value
- match &input_types[0] {
- DataType::Dictionary(_, dict_value_type) => {
- // TODO add checker, if the value type is complex data type
- Ok(vec![dict_value_type.deref().clone()])
- }
- // TODO add checker for datatype which min and max supported
- // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
- _ => Ok(input_types.to_vec()),
- }
-}
-
-pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- | DataType::Decimal128(_, _)
- )
-}
-
-pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- | DataType::Decimal128(_, _)
- )
-}
-
-pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- )
-}
-
-pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- )
-}
-
-pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- )
-}
-
-pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- )
-}
-
-pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- )
-}
-
-/// Return `true` if `arg_type` is of a [`DataType`] that the
-/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on.
-pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
- matches!(
- arg_type,
- DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float32
- | DataType::Float64
- )
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::aggregate_function;
- use arrow::datatypes::DataType;
-
- #[test]
- fn test_aggregate_coerce_types() {
- // test input args with error number input types
- let fun = AggregateFunction::Min;
- let input_types = vec![DataType::Int64, DataType::Int32];
- let signature = aggregate_function::signature(&fun);
- let result = coerce_types(&fun, &input_types, &signature);
- assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().to_string());
-
- // test input args is invalid data type for sum or avg
- let fun = AggregateFunction::Sum;
- let input_types = vec![DataType::Utf8];
- let signature = aggregate_function::signature(&fun);
- let result = coerce_types(&fun, &input_types, &signature);
- assert_eq!(
- "Error during planning: The function Sum does not support inputs of type Utf8.",
- result.unwrap_err().to_string()
- );
- let fun = AggregateFunction::Avg;
- let signature = aggregate_function::signature(&fun);
- let result = coerce_types(&fun, &input_types, &signature);
- assert_eq!(
- "Error during planning: The function Avg does not support inputs of type Utf8.",
- result.unwrap_err().to_string()
- );
-
- // test count, array_agg, approx_distinct, min, max.
- // the coerced types is same with input types
- let funs = vec![
- AggregateFunction::Count,
- AggregateFunction::ArrayAgg,
- AggregateFunction::ApproxDistinct,
- AggregateFunction::Min,
- AggregateFunction::Max,
- ];
- let input_types = vec![
- vec![DataType::Int32],
- vec![DataType::Decimal128(10, 2)],
- vec![DataType::Utf8],
- ];
- for fun in funs {
- for input_type in &input_types {
- let signature = aggregate_function::signature(&fun);
- let result = coerce_types(&fun, input_type, &signature);
- assert_eq!(*input_type, result.unwrap());
- }
- }
- // test sum, avg
- let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
- let input_types = vec![
- vec![DataType::Int32],
- vec![DataType::Float32],
- vec![DataType::Decimal128(20, 3)],
- ];
- for fun in funs {
- for input_type in &input_types {
- let signature = aggregate_function::signature(&fun);
- let result = coerce_types(&fun, input_type, &signature);
- assert_eq!(*input_type, result.unwrap());
- }
- }
-
- // ApproxPercentileCont input types
- let input_types = vec![
- vec![DataType::Int8, DataType::Float64],
- vec![DataType::Int16, DataType::Float64],
- vec![DataType::Int32, DataType::Float64],
- vec![DataType::Int64, DataType::Float64],
- vec![DataType::UInt8, DataType::Float64],
- vec![DataType::UInt16, DataType::Float64],
- vec![DataType::UInt32, DataType::Float64],
- vec![DataType::UInt64, DataType::Float64],
- vec![DataType::Float32, DataType::Float64],
- vec![DataType::Float64, DataType::Float64],
- ];
- for input_type in &input_types {
- let signature =
- aggregate_function::signature(&AggregateFunction::ApproxPercentileCont);
- let result = coerce_types(
- &AggregateFunction::ApproxPercentileCont,
- input_type,
- &signature,
- );
- assert_eq!(*input_type, result.unwrap());
- }
- }
-
- #[test]
- fn test_avg_return_data_type() -> Result<()> {
- let data_type = DataType::Decimal128(10, 5);
- let result_type = avg_return_type(&data_type)?;
- assert_eq!(DataType::Decimal128(14, 9), result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- let result_type = avg_return_type(&data_type)?;
- assert_eq!(DataType::Decimal128(38, 14), result_type);
- Ok(())
- }
-
- #[test]
- fn test_variance_return_data_type() -> Result<()> {
- let data_type = DataType::Float64;
- let result_type = variance_return_type(&data_type)?;
- assert_eq!(DataType::Float64, result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- assert!(variance_return_type(&data_type).is_err());
- Ok(())
- }
-
- #[test]
- fn test_sum_return_data_type() -> Result<()> {
- let data_type = DataType::Decimal128(10, 5);
- let result_type = sum_return_type(&data_type)?;
- assert_eq!(DataType::Decimal128(20, 5), result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- let result_type = sum_return_type(&data_type)?;
- assert_eq!(DataType::Decimal128(38, 10), result_type);
- Ok(())
- }
-
- #[test]
- fn test_stddev_return_data_type() -> Result<()> {
- let data_type = DataType::Float64;
- let result_type = stddev_return_type(&data_type)?;
- assert_eq!(DataType::Float64, result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- assert!(stddev_return_type(&data_type).is_err());
- Ok(())
- }
-
- #[test]
- fn test_covariance_return_data_type() -> Result<()> {
- let data_type = DataType::Float64;
- let result_type = covariance_return_type(&data_type)?;
- assert_eq!(DataType::Float64, result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- assert!(covariance_return_type(&data_type).is_err());
- Ok(())
- }
-
- #[test]
- fn test_correlation_return_data_type() -> Result<()> {
- let data_type = DataType::Float64;
- let result_type = correlation_return_type(&data_type)?;
- assert_eq!(DataType::Float64, result_type);
-
- let data_type = DataType::Decimal128(36, 10);
- assert!(correlation_return_type(&data_type).is_err());
- Ok(())
- }
-}
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 7ec4eddf1..88d767366 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -16,8 +16,8 @@
// under the License.
use super::Expr;
-use crate::binary_rule::binary_operator_data_type;
use crate::field_util::get_indexed_field;
+use crate::type_coercion::binary::binary_operator_data_type;
use crate::{aggregate_function, function, window_function};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 973ea8f83..e39339c20 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -18,7 +18,7 @@
//! Function module contains typing and signature for built-in and user defined functions.
use crate::nullif::SUPPORTED_NULLIF_TYPES;
-use crate::type_coercion::data_types;
+use crate::type_coercion::functions::data_types;
use crate::ColumnarValue;
use crate::{
array_expressions, conditional_expressions, struct_expressions, Accumulator,
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 90007a8bd..16b1a8f13 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -28,7 +28,6 @@
mod accumulator;
pub mod aggregate_function;
pub mod array_expressions;
-pub mod binary_rule;
mod built_in_function;
mod columnar_value;
pub mod conditional_expressions;
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index e881b05ed..3850871b0 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -17,10 +17,10 @@
//! This module provides a builder for creating LogicalPlans
-use crate::binary_rule::comparison_coercion;
use crate::expr_rewriter::{
coerce_plan_expr_for_schema, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs,
};
+use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{
columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist,
};
diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs
index 27eee3d30..41eeb3c65 100644
--- a/datafusion/expr/src/type_coercion.rs
+++ b/datafusion/expr/src/type_coercion.rs
@@ -15,247 +15,49 @@
// specific language governing permissions and limitations
// under the License.
-//! Type coercion rules for functions with multiple valid signatures
+//! Type coercion rules for DataFusion
//!
//! Coercion is performed automatically by DataFusion when the types
-//! of arguments passed to a function do not exacty match the types
-//! required by that function. In this case, DataFusion will attempt to
-//! *coerce* the arguments to types accepted by the function by
-//! inserting CAST operations.
+//! of arguments passed to a function or needed by operators do not
+//! exactly match the types required by that function / operator. In
+//! this case, DataFusion will attempt to *coerce* the arguments to
+//! types accepted by the function by inserting CAST operations.
//!
//! CAST operations added by coercion are lossless and never discard
-//! information. For example coercion from i32 -> i64 might be
+//! information.
+//!
+//! For example coercion from i32 -> i64 might be
//! performed because all valid i32 values can be represented using an
//! i64. However, i64 -> i32 is never performed as there are i64
//! values which can not be represented by i32 values.
-//!
-
-use crate::{Signature, TypeSignature};
-use arrow::{
- compute::can_cast_types,
- datatypes::{DataType, TimeUnit},
-};
-use datafusion_common::{DataFusionError, Result};
-
-/// Returns the data types that each argument must be coerced to match
-/// `signature`.
-///
-/// See the module level documentation for more detail on coercion.
-pub fn data_types(
- current_types: &[DataType],
- signature: &Signature,
-) -> Result<Vec<DataType>> {
- if current_types.is_empty() {
- return Ok(vec![]);
- }
- let valid_types = get_valid_types(&signature.type_signature, current_types)?;
-
- if valid_types
- .iter()
- .any(|data_type| data_type == current_types)
- {
- return Ok(current_types.to_vec());
- }
-
- for valid_types in valid_types {
- if let Some(types) = maybe_data_types(&valid_types, current_types) {
- return Ok(types);
- }
- }
- // none possible -> Error
- Err(DataFusionError::Plan(format!(
- "Coercion from {:?} to the signature {:?} failed.",
- current_types, &signature.type_signature
- )))
+use arrow::datatypes::DataType;
+
+/// Determine if a DataType is signed numeric or not
+pub fn is_signed_numeric(dt: &DataType) -> bool {
+ matches!(
+ dt,
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ | DataType::Float16
+ | DataType::Float32
+ | DataType::Float64
+ | DataType::Decimal128(_, _)
+ )
}
-fn get_valid_types(
- signature: &TypeSignature,
- current_types: &[DataType],
-) -> Result<Vec<Vec<DataType>>> {
- let valid_types = match signature {
- TypeSignature::Variadic(valid_types) => valid_types
- .iter()
- .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
- .collect(),
- TypeSignature::Uniform(number, valid_types) => valid_types
- .iter()
- .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
- .collect(),
- TypeSignature::VariadicEqual => {
- // one entry with the same len as current_types, whose type is `current_types[0]`.
- vec![current_types
- .iter()
- .map(|_| current_types[0].clone())
- .collect()]
- }
- TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
- TypeSignature::Any(number) => {
- if current_types.len() != *number {
- return Err(DataFusionError::Plan(format!(
- "The function expected {} arguments but received {}",
- number,
- current_types.len()
- )));
- }
- vec![(0..*number).map(|i| current_types[i].clone()).collect()]
- }
- TypeSignature::OneOf(types) => types
- .iter()
- .filter_map(|t| get_valid_types(t, current_types).ok())
- .flatten()
- .collect::<Vec<_>>(),
- };
-
- Ok(valid_types)
-}
-
-/// Try to coerce current_types into valid_types.
-fn maybe_data_types(
- valid_types: &[DataType],
- current_types: &[DataType],
-) -> Option<Vec<DataType>> {
- if valid_types.len() != current_types.len() {
- return None;
- }
-
- let mut new_type = Vec::with_capacity(valid_types.len());
- for (i, valid_type) in valid_types.iter().enumerate() {
- let current_type = ¤t_types[i];
-
- if current_type == valid_type {
- new_type.push(current_type.clone())
- } else {
- // attempt to coerce
- if can_coerce_from(valid_type, current_type) {
- new_type.push(valid_type.clone())
- } else {
- // not possible
- return None;
- }
- }
- }
- Some(new_type)
+/// Determine if a DataType is numeric or not
+pub fn is_numeric(dt: &DataType) -> bool {
+ is_signed_numeric(dt)
+ || matches!(
+ dt,
+ DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
+ )
}
-/// Return true if a value of type `type_from` can be coerced
-/// (losslessly converted) into a value of `type_to`
-///
-/// See the module level documentation for more detail on coercion.
-pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
- use self::DataType::*;
- // Null can convert to most of types
- match type_into {
- Int8 => matches!(type_from, Null | Int8),
- Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8),
- Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16),
- Int64 => matches!(
- type_from,
- Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
- ),
- UInt8 => matches!(type_from, Null | UInt8),
- UInt16 => matches!(type_from, Null | UInt8 | UInt16),
- UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32),
- UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64),
- Float32 => matches!(
- type_from,
- Null | Int8
- | Int16
- | Int32
- | Int64
- | UInt8
- | UInt16
- | UInt32
- | UInt64
- | Float32
- ),
- Float64 => matches!(
- type_from,
- Null | Int8
- | Int16
- | Int32
- | Int64
- | UInt8
- | UInt16
- | UInt32
- | UInt64
- | Float32
- | Float64
- | Decimal128(_, _)
- ),
- Timestamp(TimeUnit::Nanosecond, None) => {
- matches!(type_from, Null | Timestamp(_, None))
- }
- Utf8 | LargeUtf8 => true,
- Null => can_cast_types(type_from, type_into),
- _ => false,
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use arrow::datatypes::DataType;
-
- #[test]
- fn test_maybe_data_types() {
- // this vec contains: arg1, arg2, expected result
- let cases = vec![
- // 2 entries, same values
- (
- vec![DataType::UInt8, DataType::UInt16],
- vec![DataType::UInt8, DataType::UInt16],
- Some(vec![DataType::UInt8, DataType::UInt16]),
- ),
- // 2 entries, can coerse values
- (
- vec![DataType::UInt16, DataType::UInt16],
- vec![DataType::UInt8, DataType::UInt16],
- Some(vec![DataType::UInt16, DataType::UInt16]),
- ),
- // 0 entries, all good
- (vec![], vec![], Some(vec![])),
- // 2 entries, can't coerce
- (
- vec![DataType::Boolean, DataType::UInt16],
- vec![DataType::UInt8, DataType::UInt16],
- None,
- ),
- // u32 -> u16 is possible
- (
- vec![DataType::Boolean, DataType::UInt32],
- vec![DataType::Boolean, DataType::UInt16],
- Some(vec![DataType::Boolean, DataType::UInt32]),
- ),
- ];
-
- for case in cases {
- assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
- }
- }
-
- #[test]
- fn test_get_valid_types_one_of() -> Result<()> {
- let signature =
- TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);
-
- let invalid_types = get_valid_types(
- &signature,
- &[DataType::Int32, DataType::Int32, DataType::Int32],
- )?;
- assert_eq!(invalid_types.len(), 0);
-
- let args = vec![DataType::Int32, DataType::Int32];
- let valid_types = get_valid_types(&signature, &args)?;
- assert_eq!(valid_types.len(), 1);
- assert_eq!(valid_types[0], args);
-
- let args = vec![DataType::Int32];
- let valid_types = get_valid_types(&signature, &args)?;
- assert_eq!(valid_types.len(), 1);
- assert_eq!(valid_types[0], args);
-
- Ok(())
- }
-}
+pub mod aggregates;
+pub mod binary;
+pub mod functions;
+pub mod other;
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/type_coercion/aggregates.rs
similarity index 76%
copy from datafusion/expr/src/aggregate_function.rs
copy to datafusion/expr/src/type_coercion/aggregates.rs
index 7b8616921..72d6dc398 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -15,15 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-//! Aggregate function module contains all built-in aggregate functions definitions
-
-use crate::{Signature, TypeSignature, Volatility};
use arrow::datatypes::{
- DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
};
use datafusion_common::{DataFusionError, Result};
use std::ops::Deref;
-use std::{fmt, str::FromStr};
+
+use crate::{AggregateFunction, Signature, TypeSignature};
pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
@@ -56,142 +54,6 @@ pub static TIMES: &[DataType] = &[
DataType::Time64(TimeUnit::Nanosecond),
];
-/// Enum of all built-in aggregate functions
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
-pub enum AggregateFunction {
- /// count
- Count,
- /// sum
- Sum,
- /// min
- Min,
- /// max
- Max,
- /// avg
- Avg,
- /// median
- Median,
- /// Approximate aggregate function
- ApproxDistinct,
- /// array_agg
- ArrayAgg,
- /// Variance (Sample)
- Variance,
- /// Variance (Population)
- VariancePop,
- /// Standard Deviation (Sample)
- Stddev,
- /// Standard Deviation (Population)
- StddevPop,
- /// Covariance (Sample)
- Covariance,
- /// Covariance (Population)
- CovariancePop,
- /// Correlation
- Correlation,
- /// Approximate continuous percentile function
- ApproxPercentileCont,
- /// Approximate continuous percentile function with weight
- ApproxPercentileContWithWeight,
- /// ApproxMedian
- ApproxMedian,
- /// Grouping
- Grouping,
-}
-
-impl fmt::Display for AggregateFunction {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- // uppercase of the debug.
- write!(f, "{}", format!("{:?}", self).to_uppercase())
- }
-}
-
-impl FromStr for AggregateFunction {
- type Err = DataFusionError;
- fn from_str(name: &str) -> Result<AggregateFunction> {
- Ok(match name {
- "min" => AggregateFunction::Min,
- "max" => AggregateFunction::Max,
- "count" => AggregateFunction::Count,
- "avg" => AggregateFunction::Avg,
- "mean" => AggregateFunction::Avg,
- "sum" => AggregateFunction::Sum,
- "median" => AggregateFunction::Median,
- "approx_distinct" => AggregateFunction::ApproxDistinct,
- "array_agg" => AggregateFunction::ArrayAgg,
- "var" => AggregateFunction::Variance,
- "var_samp" => AggregateFunction::Variance,
- "var_pop" => AggregateFunction::VariancePop,
- "stddev" => AggregateFunction::Stddev,
- "stddev_samp" => AggregateFunction::Stddev,
- "stddev_pop" => AggregateFunction::StddevPop,
- "covar" => AggregateFunction::Covariance,
- "covar_samp" => AggregateFunction::Covariance,
- "covar_pop" => AggregateFunction::CovariancePop,
- "corr" => AggregateFunction::Correlation,
- "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
- "approx_percentile_cont_with_weight" => {
- AggregateFunction::ApproxPercentileContWithWeight
- }
- "approx_median" => AggregateFunction::ApproxMedian,
- "grouping" => AggregateFunction::Grouping,
- _ => {
- return Err(DataFusionError::Plan(format!(
- "There is no built-in function named {}",
- name
- )));
- }
- })
- }
-}
-
-/// Returns the datatype of the aggregate function.
-/// This is used to get the returned data type for aggregate expr.
-pub fn return_type(
- fun: &AggregateFunction,
- input_expr_types: &[DataType],
-) -> Result<DataType> {
- // Note that this function *must* return the same type that the respective physical expression returns
- // or the execution panics.
-
- let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?;
-
- match fun {
- AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
- Ok(DataType::Int64)
- }
- AggregateFunction::Max | AggregateFunction::Min => {
- // For min and max agg function, the returned type is same as input type.
- // The coerced_data_types is same with input_types.
- Ok(coerced_data_types[0].clone())
- }
- AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
- AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]),
- AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]),
- AggregateFunction::Covariance => covariance_return_type(&coerced_data_types[0]),
- AggregateFunction::CovariancePop => {
- covariance_return_type(&coerced_data_types[0])
- }
- AggregateFunction::Correlation => correlation_return_type(&coerced_data_types[0]),
- AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
- AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
- AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
- AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
- "item",
- coerced_data_types[0].clone(),
- true,
- )))),
- AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
- AggregateFunction::ApproxPercentileContWithWeight => {
- Ok(coerced_data_types[0].clone())
- }
- AggregateFunction::ApproxMedian | AggregateFunction::Median => {
- Ok(coerced_data_types[0].clone())
- }
- AggregateFunction::Grouping => Ok(DataType::Int32),
- }
-}
-
/// Returns the coerced data type for each `input_types`.
/// Different aggregate function with different input data type will get corresponding coerced data type.
pub fn coerce_types(
@@ -353,65 +215,72 @@ pub fn coerce_types(
}
}
-/// the signatures supported by the function `fun`.
-pub fn signature(fun: &AggregateFunction) -> Signature {
- // note: the physical expression must accept the type returned by this function or the execution panics.
- match fun {
- AggregateFunction::Count
- | AggregateFunction::ApproxDistinct
- | AggregateFunction::Grouping
- | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
- AggregateFunction::Min | AggregateFunction::Max => {
- let valid = STRINGS
- .iter()
- .chain(NUMERICS.iter())
- .chain(TIMESTAMPS.iter())
- .chain(DATES.iter())
- .chain(TIMES.iter())
- .cloned()
- .collect::<Vec<_>>();
- Signature::uniform(1, valid, Volatility::Immutable)
+/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
+///
+/// This method DOES NOT validate the argument types - only that (at least one,
+/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
+/// number of input types.
+fn check_arg_count(
+ agg_fun: &AggregateFunction,
+ input_types: &[DataType],
+ signature: &TypeSignature,
+) -> Result<()> {
+ match signature {
+ TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
+ if input_types.len() != *agg_count {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} expects {:?} arguments, but {:?} were provided",
+ agg_fun,
+ agg_count,
+ input_types.len()
+ )));
+ }
}
- AggregateFunction::Avg
- | AggregateFunction::Sum
- | AggregateFunction::Variance
- | AggregateFunction::VariancePop
- | AggregateFunction::Stddev
- | AggregateFunction::StddevPop
- | AggregateFunction::Median
- | AggregateFunction::ApproxMedian => {
- Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
+ TypeSignature::Exact(types) => {
+ if types.len() != input_types.len() {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} expects {:?} arguments, but {:?} were provided",
+ agg_fun,
+ types.len(),
+ input_types.len()
+ )));
+ }
}
- AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
- Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
+ TypeSignature::OneOf(variants) => {
+ let ok = variants
+ .iter()
+ .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
+ if !ok {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} does not accept {:?} function arguments.",
+ agg_fun,
+ input_types.len()
+ )));
+ }
}
- AggregateFunction::Correlation => {
- Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Aggregate functions do not support this {:?}",
+ signature
+ )));
}
- AggregateFunction::ApproxPercentileCont => {
- // Accept any numeric value paired with a float64 percentile
- let with_tdigest_size = NUMERICS.iter().map(|t| {
- TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()])
- });
- Signature::one_of(
- NUMERICS
- .iter()
- .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
- .chain(with_tdigest_size)
- .collect(),
- Volatility::Immutable,
- )
+ }
+ Ok(())
+}
+
+fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
+ // make sure that the input types only has one element.
+ assert_eq!(input_types.len(), 1);
+ // min and max support the dictionary data type
+ // unpack the dictionary to get the value
+ match &input_types[0] {
+ DataType::Dictionary(_, dict_value_type) => {
+ // TODO add checker, if the value type is complex data type
+ Ok(vec![dict_value_type.deref().clone()])
}
- AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
- // Accept any numeric value paired with a float64 percentile
- NUMERICS
- .iter()
- .map(|t| {
- TypeSignature::Exact(vec![t.clone(), t.clone(), DataType::Float64])
- })
- .collect(),
- Volatility::Immutable,
- ),
+ // TODO add checker for datatype which min and max supported
+ // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
+ _ => Ok(input_types.to_vec()),
}
}
@@ -547,75 +416,6 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
}
}
-/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
-///
-/// This method DOES NOT validate the argument types - only that (at least one,
-/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
-/// number of input types.
-fn check_arg_count(
- agg_fun: &AggregateFunction,
- input_types: &[DataType],
- signature: &TypeSignature,
-) -> Result<()> {
- match signature {
- TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
- if input_types.len() != *agg_count {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} expects {:?} arguments, but {:?} were provided",
- agg_fun,
- agg_count,
- input_types.len()
- )));
- }
- }
- TypeSignature::Exact(types) => {
- if types.len() != input_types.len() {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} expects {:?} arguments, but {:?} were provided",
- agg_fun,
- types.len(),
- input_types.len()
- )));
- }
- }
- TypeSignature::OneOf(variants) => {
- let ok = variants
- .iter()
- .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
- if !ok {
- return Err(DataFusionError::Plan(format!(
- "The function {:?} does not accept {:?} function arguments.",
- agg_fun,
- input_types.len()
- )));
- }
- }
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Aggregate functions do not support this {:?}",
- signature
- )));
- }
- }
- Ok(())
-}
-
-fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
- // make sure that the input types only has one element.
- assert_eq!(input_types.len(), 1);
- // min and max support the dictionary data type
- // unpack the dictionary to get the value
- match &input_types[0] {
- DataType::Dictionary(_, dict_value_type) => {
- // TODO add checker, if the value type is complex data type
- Ok(vec![dict_value_type.deref().clone()])
- }
- // TODO add checker for datatype which min and max supported
- // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
- _ => Ok(input_types.to_vec()),
- }
-}
-
pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/type_coercion/binary.rs
similarity index 98%
rename from datafusion/expr/src/binary_rule.rs
rename to datafusion/expr/src/type_coercion/binary.rs
index 507ab90c0..2a125d56b 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -17,6 +17,7 @@
//! Coercion rules for matching argument types for binary operators
+use crate::type_coercion::is_numeric;
use crate::Operator;
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE};
@@ -421,30 +422,6 @@ fn coercion_decimal_mathematics_type(
}
}
-/// Determine if a DataType is signed numeric or not
-pub fn is_signed_numeric(dt: &DataType) -> bool {
- matches!(
- dt,
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::Float16
- | DataType::Float32
- | DataType::Float64
- | DataType::Decimal128(_, _)
- )
-}
-
-/// Determine if a DataType is numeric or not
-pub fn is_numeric(dt: &DataType) -> bool {
- is_signed_numeric(dt)
- || matches!(
- dt,
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
- )
-}
-
/// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric
fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
match (lhs_type, rhs_type) {
diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion/functions.rs
similarity index 90%
copy from datafusion/expr/src/type_coercion.rs
copy to datafusion/expr/src/type_coercion/functions.rs
index 27eee3d30..667451a5f 100644
--- a/datafusion/expr/src/type_coercion.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -15,21 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-//! Type coercion rules for functions with multiple valid signatures
-//!
-//! Coercion is performed automatically by DataFusion when the types
-//! of arguments passed to a function do not exacty match the types
-//! required by that function. In this case, DataFusion will attempt to
-//! *coerce* the arguments to types accepted by the function by
-//! inserting CAST operations.
-//!
-//! CAST operations added by coercion are lossless and never discard
-//! information. For example coercion from i32 -> i64 might be
-//! performed because all valid i32 values can be represented using an
-//! i64. However, i64 -> i32 is never performed as there are i64
-//! values which can not be represented by i32 values.
-//!
-
use crate::{Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
@@ -37,10 +22,11 @@ use arrow::{
};
use datafusion_common::{DataFusionError, Result};
-/// Returns the data types that each argument must be coerced to match
-/// `signature`.
+/// Performs type coercion for functions. Returns the data types that
+/// each argument must be coerced to match `signature`.
///
-/// See the module level documentation for more detail on coercion.
+/// For more details on coercion in general, please see the
+/// [`type_coercion`](crate::type_coercion) module.
pub fn data_types(
current_types: &[DataType],
signature: &Signature,
diff --git a/datafusion/expr/src/type_coercion/other.rs b/datafusion/expr/src/type_coercion/other.rs
new file mode 100644
index 000000000..2419f8d1b
--- /dev/null
+++ b/datafusion/expr/src/type_coercion/other.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+
+use super::binary::comparison_coercion;
+
+/// Attempts to coerce the types of `list_types` to be comparable with the
+/// `expr_type`.
+/// Returns the common data type for `expr_type` and `list_types`
+pub fn get_coerce_type_for_list(
+ expr_type: &DataType,
+ list_types: &[DataType],
+) -> Option<DataType> {
+ list_types
+ .iter()
+ .fold(Some(expr_type.clone()), |left, right_type| match left {
+ None => None,
+ Some(left_type) => comparison_coercion(&left_type, right_type),
+ })
+}
+
+/// Find a common coerceable type for all `then_types` as well
+/// as the `else_type`, if specified.
+/// Returns the common data type for `then_types` and `else_type`
+pub fn get_coerce_type_for_case_when(
+ then_types: &[DataType],
+ else_type: &Option<DataType>,
+) -> Option<DataType> {
+ let else_type = match else_type {
+ None => then_types[0].clone(),
+ Some(data_type) => data_type.clone(),
+ };
+ then_types
+ .iter()
+ .fold(Some(else_type), |left, right_type| match left {
+ // failed to find a valid coercion in a previous iteration
+ None => None,
+ // TODO: for now, just use the `equal` coercion rule for case when. If this causes
+ // issues, refactor again.
+ Some(left_type) => comparison_coercion(&left_type, right_type),
+ })
+}
diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs
index 414f4bf6f..c37653ab0 100644
--- a/datafusion/expr/src/window_function.rs
+++ b/datafusion/expr/src/window_function.rs
@@ -22,7 +22,7 @@
//!
use crate::aggregate_function::AggregateFunction;
-use crate::type_coercion::data_types;
+use crate::type_coercion::functions::data_types;
use crate::{aggregate_function, Signature, TypeSignature, Volatility};
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index 2073713dd..bb236fdde 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -20,10 +20,13 @@
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
-use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::Subquery;
-use datafusion_expr::type_coercion::data_types;
+use datafusion_expr::type_coercion::binary::{coerce_types, comparison_coercion};
+use datafusion_expr::type_coercion::functions::data_types;
+use datafusion_expr::type_coercion::other::{
+ get_coerce_type_for_case_when, get_coerce_type_for_list,
+};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, Expr,
@@ -412,21 +415,6 @@ fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr
expr.clone().cast_to(&coerced_type, schema)
}
-/// Attempts to coerce the types of `list_types` to be comparable with the
-/// `expr_type`.
-/// Returns the common data type for `expr_type` and `list_types`
-fn get_coerce_type_for_list(
- expr_type: &DataType,
- list_types: &[DataType],
-) -> Option<DataType> {
- list_types
- .iter()
- .fold(Some(expr_type.clone()), |left, right_type| match left {
- None => None,
- Some(left_type) => comparison_coercion(&left_type, right_type),
- })
-}
-
/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
@@ -454,28 +442,6 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}
-/// Find a common coerceable type for all `then_types` as well
-/// and the `else_type`, if specified.
-/// Returns the common data type for `then_types` and `else_type`
-fn get_coerce_type_for_case_when(
- then_types: &[DataType],
- else_type: &Option<DataType>,
-) -> Option<DataType> {
- let else_type = match else_type {
- None => then_types[0].clone(),
- Some(data_type) => data_type.clone(),
- };
- then_types
- .iter()
- .fold(Some(else_type), |left, right_type| match left {
- // failed to find a valid coercion in a previous iteration
- None => None,
- // TODO: now just use the `equal` coercion rule for case when. If find the issue, and
- // refactor again.
- Some(left_type) => comparison_coercion(&left_type, right_type),
- })
-}
-
#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index e6635698c..e3154488c 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -287,7 +287,7 @@ mod tests {
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
- use datafusion_expr::aggregate_function::NUMERICS;
+ use datafusion_expr::type_coercion::aggregates::NUMERICS;
#[test]
fn test_count_arragg_approx_expr() -> Result<()> {
diff --git a/datafusion/physical-expr/src/aggregate/coercion_rule.rs b/datafusion/physical-expr/src/aggregate/coercion_rule.rs
index 7b1e26ed1..a8c68390a 100644
--- a/datafusion/physical-expr/src/aggregate/coercion_rule.rs
+++ b/datafusion/physical-expr/src/aggregate/coercion_rule.rs
@@ -21,7 +21,7 @@ use crate::expressions::try_cast;
use crate::PhysicalExpr;
use arrow::datatypes::Schema;
use datafusion_common::Result;
-use datafusion_expr::{aggregate_function, AggregateFunction, Signature};
+use datafusion_expr::{type_coercion, AggregateFunction, Signature};
use std::sync::Arc;
/// Returns the coerced exprs for each `input_exprs`.
@@ -43,7 +43,7 @@ pub fn coerce_exprs(
// get the coerced data types
let coerced_types =
- aggregate_function::coerce_types(agg_fun, &input_types, signature)?;
+ type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?;
// try cast if need
input_exprs
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 02bf0e5bd..665ef2d68 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -76,7 +76,7 @@ use arrow::record_batch::RecordBatch;
use crate::PhysicalExpr;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::binary_rule::binary_operator_data_type;
+use datafusion_expr::type_coercion::binary::binary_operator_data_type;
use datafusion_expr::{ColumnarValue, Operator};
/// Binary expression
@@ -938,7 +938,7 @@ mod tests {
use crate::expressions::{col, lit};
use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
use datafusion_common::Result;
- use datafusion_expr::binary_rule::coerce_types;
+ use datafusion_expr::type_coercion::binary::coerce_types;
// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index b1bd0a604..581d9abdc 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -308,7 +308,7 @@ mod tests {
use arrow::datatypes::DataType::Float64;
use arrow::datatypes::*;
use datafusion_common::ScalarValue;
- use datafusion_expr::binary_rule::comparison_coercion;
+ use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::Operator;
#[test]
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index ae49b4cec..1e9fa7e9b 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -930,7 +930,7 @@ mod tests {
use crate::expressions;
use crate::expressions::{col, lit, try_cast};
use datafusion_common::Result;
- use datafusion_expr::binary_rule::comparison_coercion;
+ use datafusion_expr::type_coercion::binary::comparison_coercion;
type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs
index 8619df1b3..233ad4254 100644
--- a/datafusion/physical-expr/src/expressions/negative.rs
+++ b/datafusion/physical-expr/src/expressions/negative.rs
@@ -30,7 +30,7 @@ use arrow::{
use crate::PhysicalExpr;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{binary_rule::is_signed_numeric, ColumnarValue};
+use datafusion_expr::{type_coercion::is_signed_numeric, ColumnarValue};
/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 6997adc46..5796f8f7d 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -26,8 +26,9 @@
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
//!
-//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
-//! to a function that supports f64, it is coerced to f64.
+//! This module also supports coercion to improve user experience: if
+//! an i32 argument is passed to a function that supports f64, the
+//! argument is automatically coerced to f64.
use crate::execution_props::ExecutionProps;
use crate::{
diff --git a/datafusion/physical-expr/src/type_coercion.rs b/datafusion/physical-expr/src/type_coercion.rs
index c7648cc26..dca6ad086 100644
--- a/datafusion/physical-expr/src/type_coercion.rs
+++ b/datafusion/physical-expr/src/type_coercion.rs
@@ -33,7 +33,7 @@ use super::PhysicalExpr;
use crate::expressions::try_cast;
use arrow::datatypes::Schema;
use datafusion_common::Result;
-use datafusion_expr::{type_coercion::data_types, Signature};
+use datafusion_expr::{type_coercion::functions::data_types, Signature};
use std::{sync::Arc, vec};
/// Returns `expressions` coerced to types compatible with