You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/07/23 16:51:45 UTC
[arrow-datafusion] branch main updated: Add more Decimal256 type coercion (#7047)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b135a96a96 Add more Decimal256 type coercion (#7047)
b135a96a96 is described below
commit b135a96a9647d0b7df6099bbd222e5280e75d0a8
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sun Jul 23 09:51:38 2023 -0700
Add more Decimal256 type coercion (#7047)
---
datafusion/expr/src/type_coercion/binary.rs | 60 +++++++++++++++++++++++++++--
1 file changed, 57 insertions(+), 3 deletions(-)
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index c510822445..b6392e2a6b 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -20,6 +20,7 @@
use arrow::compute::can_cast_types;
use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use datafusion_common::DataFusionError;
@@ -248,6 +249,17 @@ fn math_decimal_coercion(
(Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => {
Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone()))
}
+ (Decimal256(_, _), Decimal256(_, _)) => {
+ Some((lhs_type.clone(), rhs_type.clone()))
+ }
+ (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some((
+ lhs_type.clone(),
+ coerce_numeric_type_to_decimal256(rhs_type)?,
+ )),
+ (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some((
+ coerce_numeric_type_to_decimal256(lhs_type)?,
+ rhs_type.clone(),
+ )),
_ => None,
}
}
@@ -383,6 +395,11 @@ fn comparison_binary_numeric_coercion(
}
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
+ (Decimal256(_, _), Decimal256(_, _)) => {
+ get_wider_decimal_type(lhs_type, rhs_type)
+ }
+ (Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
+ (_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
// The following match arms encode the following logic: Given the two
@@ -427,9 +444,15 @@ fn get_comparison_common_decimal_type(
other_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
- let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
- match (decimal_type, &other_decimal_type) {
- (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
+ match decimal_type {
+ Decimal128(_, _) => {
+ let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
+ get_wider_decimal_type(decimal_type, &other_decimal_type)
+ }
+ Decimal256(_, _) => {
+ let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?;
+ get_wider_decimal_type(decimal_type, &other_decimal_type)
+ }
_ => None,
}
}
@@ -449,6 +472,12 @@ fn get_wider_decimal_type(
let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
Some(create_decimal_type((range + s) as u8, s))
}
+ (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
+ // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+ let s = *s1.max(s2);
+ let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
+ Some(create_decimal256_type((range + s) as u8, s))
+ }
(_, _) => None,
}
}
@@ -471,6 +500,24 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
}
}
+/// Returns the `Decimal256` type that `numeric_type` is coerced to, or `None`
+/// when the type is unsupported. Only signed integer and floating-point types
+/// are currently supported.
+fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    // The (precision, scale) pairs below follow the conversion rule from Spark:
+    // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
+    let (precision, scale) = match numeric_type {
+        Int8 => (3, 0),
+        Int16 => (5, 0),
+        Int32 => (10, 0),
+        Int64 => (20, 0),
+        // TODO: converting floating-point data to a decimal type may overflow.
+        Float32 => (14, 7),
+        Float64 => (30, 15),
+        _ => return None,
+    };
+    Some(Decimal256(precision, scale))
+}
+
/// Returns the output type of applying mathematics operations such as
/// `+` to arguments of `lhs_type` and `rhs_type`.
fn mathematics_numerical_coercion(
@@ -517,6 +564,13 @@ fn create_decimal_type(precision: u8, scale: i8) -> DataType {
)
}
+/// Builds a `Decimal256` data type from `precision` and `scale`, capping them
+/// at `DECIMAL256_MAX_PRECISION` and `DECIMAL256_MAX_SCALE` respectively.
+fn create_decimal256_type(precision: u8, scale: i8) -> DataType {
+    let capped_precision = precision.min(DECIMAL256_MAX_PRECISION);
+    let capped_scale = scale.min(DECIMAL256_MAX_SCALE);
+    DataType::Decimal256(capped_precision, capped_scale)
+}
+
/// Returns the coerced type of applying mathematics operations on decimal types.
/// Two sides of the mathematics operation will be coerced to the same type. Note
/// that we don't coerce the decimal operands in analysis phase, but do it in the