You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/07/23 16:51:45 UTC

[arrow-datafusion] branch main updated: Add more Decimal256 type coercion (#7047)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b135a96a96 Add more Decimal256 type coercion (#7047)
b135a96a96 is described below

commit b135a96a9647d0b7df6099bbd222e5280e75d0a8
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Sun Jul 23 09:51:38 2023 -0700

    Add more Decimal256 type coercion (#7047)
---
 datafusion/expr/src/type_coercion/binary.rs | 60 +++++++++++++++++++++++++++--
 1 file changed, 57 insertions(+), 3 deletions(-)

diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index c510822445..b6392e2a6b 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -20,6 +20,7 @@
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
     DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
 
 use datafusion_common::DataFusionError;
@@ -248,6 +249,17 @@ fn math_decimal_coercion(
         (Int8 | Int16 | Int32 | Int64, Decimal128(_, _)) => {
             Some((coerce_numeric_type_to_decimal(lhs_type)?, rhs_type.clone()))
         }
+        (Decimal256(_, _), Decimal256(_, _)) => {
+            Some((lhs_type.clone(), rhs_type.clone()))
+        }
+        (Decimal256(_, _), Int8 | Int16 | Int32 | Int64) => Some((
+            lhs_type.clone(),
+            coerce_numeric_type_to_decimal256(rhs_type)?,
+        )),
+        (Int8 | Int16 | Int32 | Int64, Decimal256(_, _)) => Some((
+            coerce_numeric_type_to_decimal256(lhs_type)?,
+            rhs_type.clone(),
+        )),
         _ => None,
     }
 }
@@ -383,6 +395,11 @@ fn comparison_binary_numeric_coercion(
         }
         (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
         (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
+        (Decimal256(_, _), Decimal256(_, _)) => {
+            get_wider_decimal_type(lhs_type, rhs_type)
+        }
+        (Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
+        (_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
         (Float64, _) | (_, Float64) => Some(Float64),
         (_, Float32) | (Float32, _) => Some(Float32),
         // The following match arms encode the following logic: Given the two
@@ -427,9 +444,15 @@ fn get_comparison_common_decimal_type(
     other_type: &DataType,
 ) -> Option<DataType> {
     use arrow::datatypes::DataType::*;
-    let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
-    match (decimal_type, &other_decimal_type) {
-        (d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
+    match decimal_type {
+        Decimal128(_, _) => {
+            let other_decimal_type = coerce_numeric_type_to_decimal(other_type)?;
+            get_wider_decimal_type(decimal_type, &other_decimal_type)
+        }
+        Decimal256(_, _) => {
+            let other_decimal_type = coerce_numeric_type_to_decimal256(other_type)?;
+            get_wider_decimal_type(decimal_type, &other_decimal_type)
+        }
         _ => None,
     }
 }
@@ -449,6 +472,12 @@ fn get_wider_decimal_type(
             let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
             Some(create_decimal_type((range + s) as u8, s))
         }
+        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => {
+            // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+            let s = *s1.max(s2);
+            let range = (*p1 as i8 - s1).max(*p2 as i8 - s2);
+            Some(create_decimal256_type((range + s) as u8, s))
+        }
         (_, _) => None,
     }
 }
@@ -471,6 +500,24 @@ fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
     }
 }
 
+/// Convert the numeric data type to the corresponding `Decimal256` data type.
+/// Only signed integer and floating-point types are supported; other types yield `None`.
+fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    // This conversion rule is from Spark:
+    // https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
+    match numeric_type {
+        Int8 => Some(Decimal256(3, 0)),
+        Int16 => Some(Decimal256(5, 0)),
+        Int32 => Some(Decimal256(10, 0)),
+        Int64 => Some(Decimal256(20, 0)),
+        // TODO: converting floating-point data to the decimal type may overflow.
+        Float32 => Some(Decimal256(14, 7)),
+        Float64 => Some(Decimal256(30, 15)),
+        _ => None,
+    }
+}
+
 /// Returns the output type of applying mathematics operations such as
 /// `+` to arguments of `lhs_type` and `rhs_type`.
 fn mathematics_numerical_coercion(
@@ -517,6 +564,13 @@ fn create_decimal_type(precision: u8, scale: i8) -> DataType {
     )
 }
 
+fn create_decimal256_type(precision: u8, scale: i8) -> DataType {
+    DataType::Decimal256(
+        DECIMAL256_MAX_PRECISION.min(precision), // cap precision at the Decimal256 maximum
+        DECIMAL256_MAX_SCALE.min(scale), // cap scale at the Decimal256 maximum
+    )
+}
+
 /// Returns the coerced type of applying mathematics operations on decimal types.
 /// Two sides of the mathematics operation will be coerced to the same type. Note
 /// that we don't coerce the decimal operands in analysis phase, but do it in the