You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/18 02:12:38 UTC

[GitHub] [arrow-datafusion] yjshen commented on a diff in pull request #2257: Move logical expression type-coercion code from `physical-expr` crate to `expr` crate

yjshen commented on code in PR #2257:
URL: https://github.com/apache/arrow-datafusion/pull/2257#discussion_r851839766


##########
datafusion/expr/src/aggregate_function.rs:
##########
@@ -101,3 +130,733 @@ impl FromStr for AggregateFunction {
         })
     }
 }
+
+/// Returns the datatype of the aggregate function.
+/// This is used to get the returned data type for aggregate expr.
+pub fn return_type(
+    fun: &AggregateFunction,
+    input_expr_types: &[DataType],
+) -> Result<DataType> {
+    // Note that this function *must* return the same type that the respective physical expression returns
+    // or the execution panics.
+
+    let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?;
+
+    match fun {
+        // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64.
+        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+            Ok(DataType::UInt64)
+        }
+        AggregateFunction::Max | AggregateFunction::Min => {
+            // For min and max agg function, the returned type is same as input type.
+            // The coerced_data_types is same with input_types.
+            Ok(coerced_data_types[0].clone())
+        }
+        AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
+        AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]),
+        AggregateFunction::VariancePop => variance_return_type(&coerced_data_types[0]),
+        AggregateFunction::Covariance => covariance_return_type(&coerced_data_types[0]),
+        AggregateFunction::CovariancePop => {
+            covariance_return_type(&coerced_data_types[0])
+        }
+        AggregateFunction::Correlation => correlation_return_type(&coerced_data_types[0]),
+        AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]),
+        AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]),
+        AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
+        AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
+            "item",
+            coerced_data_types[0].clone(),
+            true,
+        )))),
+        AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
+        AggregateFunction::ApproxPercentileContWithWeight => {
+            Ok(coerced_data_types[0].clone())
+        }
+        AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
+    }
+}
+
+/// Returns the coerced data type for each `input_types`.
+/// Different aggregate function with different input data type will get corresponding coerced data type.
+pub fn coerce_types(
+    agg_fun: &AggregateFunction,
+    input_types: &[DataType],
+    signature: &Signature,
+) -> Result<Vec<DataType>> {
+    // Validate input_types matches (at least one of) the func signature.
+    check_arg_count(agg_fun, input_types, &signature.type_signature)?;
+
+    match agg_fun {
+        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
+        AggregateFunction::Min | AggregateFunction::Max => {
+            // min and max support the dictionary data type
+            // unpack the dictionary to get the value
+            get_min_max_result_type(input_types)
+        }
+        AggregateFunction::Sum => {
+            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+            // smallint, int, bigint, real, double precision, decimal, or interval.
+            if !is_sum_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Avg => {
+            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
+            // smallint, int, bigint, real, double precision, decimal, or interval
+            if !is_avg_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Variance => {
+            if !is_variance_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::VariancePop => {
+            if !is_variance_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Covariance => {
+            if !is_covariance_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::CovariancePop => {
+            if !is_covariance_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Stddev => {
+            if !is_stddev_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::StddevPop => {
+            if !is_stddev_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Correlation => {
+            if !is_correlation_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::ApproxPercentileCont => {
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            if !matches!(input_types[1], DataType::Float64) {
+                return Err(DataFusionError::Plan(format!(
+                    "The percentile argument for {:?} must be Float64, not {:?}.",
+                    agg_fun, input_types[1]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::ApproxPercentileContWithWeight => {
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The weight argument for {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[1]
+                )));
+            }
+            if !matches!(input_types[2], DataType::Float64) {
+                return Err(DataFusionError::Plan(format!(
+                    "The percentile argument for {:?} must be Float64, not {:?}.",
+                    agg_fun, input_types[2]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::ApproxMedian => {
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+    }
+}
+
+/// the signatures supported by the function `fun`.
+pub fn signature(fun: &AggregateFunction) -> Signature {
+    // note: the physical expression must accept the type returned by this function or the execution panics.
+    match fun {
+        AggregateFunction::Count
+        | AggregateFunction::ApproxDistinct
+        | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
+        AggregateFunction::Min | AggregateFunction::Max => {
+            let valid = STRINGS
+                .iter()
+                .chain(NUMERICS.iter())
+                .chain(TIMESTAMPS.iter())
+                .chain(DATES.iter())
+                .cloned()
+                .collect::<Vec<_>>();
+            Signature::uniform(1, valid, Volatility::Immutable)
+        }
+        AggregateFunction::Avg
+        | AggregateFunction::Sum
+        | AggregateFunction::Variance
+        | AggregateFunction::VariancePop
+        | AggregateFunction::Stddev
+        | AggregateFunction::StddevPop
+        | AggregateFunction::ApproxMedian => {
+            Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
+        }
+        AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
+            Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
+        }
+        AggregateFunction::Correlation => {
+            Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
+        }
+        AggregateFunction::ApproxPercentileCont => Signature::one_of(
+            // Accept any numeric value paired with a float64 percentile
+            NUMERICS
+                .iter()
+                .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
+                .collect(),
+            Volatility::Immutable,
+        ),
+        AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
+            // Accept any numeric value paired with a float64 percentile
+            NUMERICS
+                .iter()
+                .map(|t| {
+                    TypeSignature::Exact(vec![t.clone(), t.clone(), DataType::Float64])
+                })
+                .collect(),
+            Volatility::Immutable,
+        ),
+    }
+}
+
+/// function return type of a sum
+pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+            Ok(DataType::Int64)
+        }
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+            Ok(DataType::UInt64)
+        }
+        // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
+        // the result type of floating-point is FLOAT64 with the double precision.
+        DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
+        DataType::Decimal(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+10), s)
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
+            let new_precision = DECIMAL_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal(new_precision, *scale))
+        }
+        other => Err(DataFusionError::Plan(format!(
+            "SUM does not support type \"{:?}\"",
+            other
+        ))),
+    }
+}
+
+/// function return type of variance
+pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "VAR does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// function return type of covariance
+pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "COVAR does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// function return type of correlation
+pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "CORR does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// function return type of standard deviation
+pub fn stddev_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "STDDEV does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// function return type of an average
+pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Decimal(precision, scale) => {
+            // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
+            // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
+            let new_precision = DECIMAL_MAX_PRECISION.min(*precision + 4);
+            let new_scale = DECIMAL_MAX_SCALE.min(*scale + 4);
+            Ok(DataType::Decimal(new_precision, new_scale))
+        }
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "AVG does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
+///
+/// This method DOES NOT validate the argument types - only that (at least one,
+/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
+/// number of input types.
+fn check_arg_count(
+    agg_fun: &AggregateFunction,
+    input_types: &[DataType],
+    signature: &TypeSignature,
+) -> Result<()> {
+    match signature {
+        TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
+            if input_types.len() != *agg_count {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects {:?} arguments, but {:?} were provided",
+                    agg_fun,
+                    agg_count,
+                    input_types.len()
+                )));
+            }
+        }
+        TypeSignature::Exact(types) => {
+            if types.len() != input_types.len() {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects {:?} arguments, but {:?} were provided",
+                    agg_fun,
+                    types.len(),
+                    input_types.len()
+                )));
+            }
+        }
+        TypeSignature::OneOf(variants) => {
+            let ok = variants
+                .iter()
+                .any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
+            if !ok {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not accept {:?} function arguments.",
+                    agg_fun,
+                    input_types.len()
+                )));
+            }
+        }
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "Aggregate functions do not support this {:?}",
+                signature
+            )));
+        }
+    }
+    Ok(())
+}
+
+fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
+    // make sure that the input types only has one element.
+    assert_eq!(input_types.len(), 1);
+    // min and max support the dictionary data type
+    // unpack the dictionary to get the value
+    match &input_types[0] {
+        DataType::Dictionary(_, dict_value_type) => {
+            // TODO add checker, if the value type is complex data type
+            Ok(vec![dict_value_type.deref().clone()])
+        }
+        // TODO add checker for datatype which min and max supported
+        // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
+        _ => Ok(input_types.to_vec()),
+    }
+}
+
+pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Decimal(_, _)
+    )
+}
+
+pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Decimal(_, _)
+    )
+}
+
+pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+    )
+}
+
+pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+    )
+}
+
+pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+    )
+}
+
+pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
+    matches!(
+        arg_type,
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64
+    )
+}
+
+/// Return `true` if `arg_type` is of a [`DataType`] that the
+/// [`ApproxPercentileCont`] aggregation can operate on.
+pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {

Review Comment:
   I've considered introducing a trait with `supported_arg_type` and `return_type` methods before, but it doesn't seem to make much of a difference.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org