You are viewing a plain-text version of this content; the hyperlink to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@doris.apache.org by ga...@apache.org on 2022/12/29 07:35:24 UTC

[doris] branch master updated: [Bug](Decimalv3) coredump of decimalv3 multiply (#15452)

This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new c22ba8e160 [Bug](Decimalv3) coredump of decimalv3 multiply (#15452)
c22ba8e160 is described below

commit c22ba8e1608a561f0e46d4100bc2a31f162fc307
Author: HappenLee <ha...@hotmail.com>
AuthorDate: Thu Dec 29 15:35:17 2022 +0800

    [Bug](Decimalv3) coredump of decimalv3 multiply (#15452)
---
 be/src/vec/core/decimal_comparison.h               |  8 +--
 be/src/vec/data_types/data_type_decimal.h          | 60 ++++++----------------
 be/src/vec/functions/function_binary_arithmetic.h  |  7 ++-
 regression-test/data/decimalv3/test_decimalv3.out  |  3 ++
 .../suites/decimalv3/test_decimalv3.groovy         |  1 +
 5 files changed, 29 insertions(+), 50 deletions(-)

diff --git a/be/src/vec/core/decimal_comparison.h b/be/src/vec/core/decimal_comparison.h
index 1f2bce4200..4c7cfcf765 100644
--- a/be/src/vec/core/decimal_comparison.h
+++ b/be/src/vec/core/decimal_comparison.h
@@ -142,9 +142,11 @@ private:
 
         Shift shift;
         if (decimal0 && decimal1) {
-            auto result_type = decimal_result_type(*decimal0, *decimal1, false, false);
-            shift.a = result_type.scale_factor_for(*decimal0, false);
-            shift.b = result_type.scale_factor_for(*decimal1, false);
+            using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
+            auto type_ptr = decimal_result_type(*decimal0, *decimal1, false, false, false);
+            const DataTypeDecimal<Type>* result_type = check_decimal<Type>(*type_ptr);
+            shift.a = result_type->scale_factor_for(*decimal0, false);
+            shift.b = result_type->scale_factor_for(*decimal1, false);
         } else if (decimal0) {
             shift.b = decimal0->get_scale_multiplier();
         } else if (decimal1) {
diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h
index ea29954293..c8e08303a1 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -219,56 +219,37 @@ private:
 };
 
+// Returns the result DataType of a binary op on two decimal operands.
+// DecimalV2 x DecimalV2 keeps the wider operand type with scale 9; otherwise precision
+// and scale follow the op kind (multiply / divide / plus-minus) and create_decimal builds the type.
 template <typename T, typename U>
-typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimal_result_type(
-        const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
-        bool is_divide) {
+DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty,
+                                bool is_multiply, bool is_divide, bool is_plus_minus) {
+    using Type = std::conditional_t<sizeof(T) >= sizeof(U), T, U>;
     if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
-        return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
+        // Pass precision and scale as two separate arguments: wrapping them in an extra pair of
+        // parentheses forms a comma expression and forwards only the single value 9.
+        return std::make_shared<DataTypeDecimal<Type>>(max_decimal_precision<T>(), 9);
     } else {
-        UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
+        UInt32 scale = std::max(tx.get_scale(), ty.get_scale());
+        auto precision = max_decimal_precision<Type>();
+
+        size_t multiply_precision = tx.get_precision() + ty.get_precision();
+        size_t divide_precision = tx.get_precision() + ty.get_scale();
+        size_t plus_minus_precision =
+                std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) +
+                scale;
         if (is_multiply) {
             scale = tx.get_scale() + ty.get_scale();
+            // NOTE(review): scale is not clamped to the capped precision below; confirm callers
+            // never reach s1 + s2 > precision (or that create_decimal rejects that case).
+            precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>());
         } else if (is_divide) {
             scale = tx.get_scale();
+            precision = std::min(divide_precision, max_decimal_precision<Decimal128I>());
+        } else if (is_plus_minus) {
+            precision = std::min(plus_minus_precision, max_decimal_precision<Decimal128I>());
         }
-        return DataTypeDecimal<T>(max_decimal_precision<T>(), scale);
-    }
-}
-
-template <typename T, typename U>
-typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimal_result_type(
-        const DataTypeDecimal<T>& tx, const DataTypeDecimal<U>& ty, bool is_multiply,
-        bool is_divide) {
-    if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
-        return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
-    } else {
-        UInt32 scale = (tx.get_scale() > ty.get_scale() ? tx.get_scale() : ty.get_scale());
-        if (is_multiply) {
-            scale = tx.get_scale() + ty.get_scale();
-        } else if (is_divide) {
-            scale = tx.get_scale();
-        }
-        return DataTypeDecimal<U>(max_decimal_precision<U>(), scale);
-    }
-}
-
-template <typename T, typename U>
-const DataTypeDecimal<T> decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeNumber<U>&,
-                                             bool, bool) {
-    if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
-        return DataTypeDecimal<T>(max_decimal_precision<T>(), 9);
-    } else {
-        return DataTypeDecimal<T>(max_decimal_precision<T>(), tx.get_scale());
-    }
-}
-
-template <typename T, typename U>
-const DataTypeDecimal<U> decimal_result_type(const DataTypeNumber<T>&, const DataTypeDecimal<U>& ty,
-                                             bool, bool) {
-    if constexpr (IsDecimalV2<T> && IsDecimalV2<U>) {
-        return DataTypeDecimal<U>(max_decimal_precision<U>(), 9);
-    } else {
-        return DataTypeDecimal<U>(max_decimal_precision<U>(), ty.get_scale());
+        return create_decimal(precision, scale, false);
     }
 }
 
diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h
index 5c98e72486..2a8da748e3 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -730,10 +730,9 @@ public:
                     if constexpr (!std::is_same_v<ResultDataType, InvalidType>) {
                         if constexpr (IsDataTypeDecimal<LeftDataType> &&
                                       IsDataTypeDecimal<RightDataType>) {
-                            ResultDataType result_type = decimal_result_type(
-                                    left, right, OpTraits::is_multiply, OpTraits::is_division);
-                            type_res = std::make_shared<ResultDataType>(result_type.get_precision(),
-                                                                        result_type.get_scale());
+                            type_res = decimal_result_type(left, right, OpTraits::is_multiply,
+                                                           OpTraits::is_division,
+                                                           OpTraits::is_plus_minus);
                         } else if constexpr (IsDataTypeDecimal<LeftDataType>) {
                             type_res = std::make_shared<LeftDataType>(left.get_precision(),
                                                                       left.get_scale());
diff --git a/regression-test/data/decimalv3/test_decimalv3.out b/regression-test/data/decimalv3/test_decimalv3.out
index 1bb8b045c0..f8d56b4c41 100644
--- a/regression-test/data/decimalv3/test_decimalv3.out
+++ b/regression-test/data/decimalv3/test_decimalv3.out
@@ -2,3 +2,6 @@
 -- !decimalv3 --
 100.000000000000000000
 
+-- !decimalv3 --
+100.00000000000000000000
+
diff --git a/regression-test/suites/decimalv3/test_decimalv3.groovy b/regression-test/suites/decimalv3/test_decimalv3.groovy
index 374e554b93..8b8b010240 100644
--- a/regression-test/suites/decimalv3/test_decimalv3.groovy
+++ b/regression-test/suites/decimalv3/test_decimalv3.groovy
@@ -26,4 +26,5 @@ suite("test_decimalv3") {
 	sql "create view test5_v (amout) as select cast(a*b as decimalv3(38,18)) from test5"
 
 	qt_decimalv3 "select * from test5_v"
+	qt_decimalv3 "select cast(a as decimalv3(12,10)) * cast(b as decimalv3(18,10)) from test5"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org