You are viewing a plain-text version of this content. The canonical link to it (the word "here" on the original archive page) was not preserved in this plain-text copy.
Posted to commits@doris.apache.org by ga...@apache.org on 2022/12/30 06:02:30 UTC

[doris] branch master updated: [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)

This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 520b6d7910 [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)
520b6d7910 is described below

commit 520b6d791044425c211fea54c088d5f10893b060
Author: Gabriel <ga...@gmail.com>
AuthorDate: Fri Dec 30 14:02:24 2022 +0800

    [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)
---
 be/src/runtime/runtime_state.h                     |   5 +
 be/src/udf/udf_internal.h                          |   8 ++
 be/src/vec/data_types/data_type_decimal.h          |  30 ++++-
 be/src/vec/exprs/vexpr_context.cpp                 |   2 +
 be/src/vec/functions/function_binary_arithmetic.h  | 142 ++++++++++++++-------
 be/src/vec/functions/function_cast.h               | 124 +++++++++++-------
 .../org/apache/doris/analysis/ArithmeticExpr.java  |  19 +--
 .../main/java/org/apache/doris/analysis/Expr.java  |  21 +++
 .../java/org/apache/doris/qe/SessionVariable.java  |  10 ++
 gensrc/thrift/PaloInternalService.thrift           |   1 +
 .../data/datatype_p0/decimalv3/test_overflow.out   |  19 +++
 .../datatype_p0/decimalv3/test_overflow.groovy     |  56 ++++++++
 12 files changed, 335 insertions(+), 102 deletions(-)

diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index dedef5340d..e9650d0702 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -135,6 +135,11 @@ public:
                _query_options.enable_function_pushdown;
     }
 
+    bool check_overflow_for_decimal() const {
+        return _query_options.__isset.check_overflow_for_decimal &&
+               _query_options.check_overflow_for_decimal;
+    }
+
     // Create a codegen object in _codegen. No-op if it has already been called.
     // If codegen is enabled for the query, this is created when the runtime
     // state is created. If codegen is disabled for the query, this is created
diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h
index 1bc4fefd0b..67a8ec60e7 100644
--- a/be/src/udf/udf_internal.h
+++ b/be/src/udf/udf_internal.h
@@ -109,6 +109,12 @@ public:
 
     const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; }
 
+    const bool check_overflow_for_decimal() const { return _check_overflow_for_decimal; }
+
+    bool set_check_overflow_for_decimal(bool check_overflow_for_decimal) {
+        return _check_overflow_for_decimal = check_overflow_for_decimal;
+    }
+
 private:
     friend class doris_udf::FunctionContext;
     friend class ExprContext;
@@ -181,6 +187,8 @@ private:
     // call that passes the correct AnyVal subclass pointer type.
     std::vector<doris_udf::AnyVal*> _staging_input_vals;
 
+    bool _check_overflow_for_decimal = false;
+
     // Indicates whether this context has been closed. Used for verification/debugging.
     bool _closed;
 
diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h
index c8e08303a1..2213093104 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -71,6 +71,11 @@ constexpr Int128 max_decimal_value<Decimal128>() {
     return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
            static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
 }
+template <>
+constexpr Int128 max_decimal_value<Decimal128I>() {
+    return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
+           static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
+}
 
 DataTypePtr create_decimal(UInt64 precision, UInt64 scale, bool use_v2);
 
@@ -291,8 +296,8 @@ constexpr bool IsDataTypeDecimalOrNumber =
 template <typename FromDataType, typename ToDataType>
 inline std::enable_if_t<IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>,
                         typename ToDataType::FieldType>
-convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from,
-                 UInt32 scale_to) {
+convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from, UInt32 scale_to,
+                 UInt8* overflow_flag = nullptr) {
     using FromFieldType = typename FromDataType::FieldType;
     using ToFieldType = typename ToDataType::FieldType;
     using MaxFieldType =
@@ -310,6 +315,9 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
                 DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_to - scale_from);
         if (common::mul_overflow(static_cast<MaxNativeType>(value), converted_value,
                                  converted_value)) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow";
             return converted_value < 0
                            ? std::numeric_limits<typename ToFieldType::NativeType>::min()
@@ -322,10 +330,16 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
 
     if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType)) {
         if (converted_value < std::numeric_limits<typename ToFieldType::NativeType>::min()) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow";
             return std::numeric_limits<typename ToFieldType::NativeType>::min();
         }
         if (converted_value > std::numeric_limits<typename ToFieldType::NativeType>::max()) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow";
             return std::numeric_limits<typename ToFieldType::NativeType>::max();
         }
@@ -381,12 +395,16 @@ convert_from_decimal(const typename FromDataType::FieldType& value, UInt32 scale
 template <typename FromDataType, typename ToDataType>
 inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>,
                         typename ToDataType::FieldType>
-convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) {
+convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale,
+                   UInt8* overflow_flag) {
     using FromFieldType = typename FromDataType::FieldType;
     using ToNativeType = typename ToDataType::FieldType::NativeType;
 
     if constexpr (std::is_floating_point_v<FromFieldType>) {
         if (!std::isfinite(value)) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow. Cannot convert infinity or NaN to decimal";
             return value < 0 ? std::numeric_limits<ToNativeType>::min()
                              : std::numeric_limits<ToNativeType>::max();
@@ -395,10 +413,16 @@ convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale)
         FromFieldType out;
         out = value * ToDataType::get_scale_multiplier(scale);
         if (out <= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::min())) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
             return std::numeric_limits<ToNativeType>::min();
         }
         if (out >= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::max())) {
+            if (overflow_flag) {
+                *overflow_flag = 1;
+            }
             VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
             return std::numeric_limits<ToNativeType>::max();
         }
diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp
index ccb1045cb1..9033245202 100644
--- a/be/src/vec/exprs/vexpr_context.cpp
+++ b/be/src/vec/exprs/vexpr_context.cpp
@@ -110,6 +110,8 @@ int VExprContext::register_func(RuntimeState* state, const FunctionContext::Type
                                 int varargs_buffer_size) {
     _fn_contexts.push_back(FunctionContextImpl::create_context(
             state, _pool.get(), return_type, arg_types, varargs_buffer_size, false));
+    _fn_contexts.back()->impl()->set_check_overflow_for_decimal(
+            state->check_overflow_for_decimal());
     return _fn_contexts.size() - 1;
 }
 
diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h
index 2a8da748e3..d1c0375a55 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -23,6 +23,7 @@
 #include <type_traits>
 
 #include "runtime/decimalv2_value.h"
+#include "udf/udf_internal.h"
 #include "vec/columns/column_const.h"
 #include "vec/columns/column_decimal.h"
 #include "vec/columns/column_nullable.h"
@@ -216,7 +217,8 @@ struct BinaryOperationImpl {
 /// *   no agrs scale. ScaleR = Scale1 + Scale2;
 /// /   first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
 template <typename A, typename B, template <typename, typename> typename Operation,
-          typename ResultType, bool is_to_null_type, bool check_overflow = true>
+          typename ResultType, bool is_to_null_type, bool return_nullable_type,
+          bool check_overflow = true>
 struct DecimalBinaryOperation {
     using OpTraits = OperationTraits<Operation>;
 
@@ -249,12 +251,14 @@ struct DecimalBinaryOperation {
             for (size_t i = 0; i < size; ++i) {
                 c[i] = apply(a[i], b[i], null_map[i]);
             }
-        } else {
-            if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
-                for (size_t i = 0; i < size; ++i) {
-                    c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
-                }
-                return;
+        } else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+            for (size_t i = 0; i < size; ++i) {
+                c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
+            }
+        } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+                             (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+            for (size_t i = 0; i < size; ++i) {
+                null_map[i] = apply_op_safely(a[i], b[i], c[i].value);
             }
         }
     }
@@ -281,21 +285,21 @@ struct DecimalBinaryOperation {
             for (size_t i = 0; i < size; ++i) {
                 c[i] = apply_scaled_div(a[i], b, null_map[i]);
             }
-            return;
-        }
-
-        for (size_t i = 0; i < size; ++i) {
-            c[i] = apply(a[i], b, null_map[i]);
+        } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+                             (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+            for (size_t i = 0; i < size; ++i) {
+                null_map[i] = apply_op_safely(a[i], b, c[i].value);
+            }
+        } else {
+            for (size_t i = 0; i < size; ++i) {
+                c[i] = apply(a[i], b, null_map[i]);
+            }
         }
     }
 
     static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c) {
         size_t size = b.size();
-        if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
-            for (size_t i = 0; i < size; ++i) {
-                c[i] = apply_scaled_div(a, b[i]);
-            }
-        } else if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
+        if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
             DecimalV2Value da(a);
             for (size_t i = 0; i < size; ++i) {
                 c[i] = Op::template apply(da, DecimalV2Value(b[i])).value();
@@ -314,33 +318,43 @@ struct DecimalBinaryOperation {
             for (size_t i = 0; i < size; ++i) {
                 c[i] = apply_scaled_div(a, b[i], null_map[i]);
             }
-            return;
-        }
-
-        for (size_t i = 0; i < size; ++i) {
-            c[i] = apply(a, b[i], null_map[i]);
+        } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+                             (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+            for (size_t i = 0; i < size; ++i) {
+                null_map[i] = apply_op_safely(a, b[i], c[i].value);
+            }
+        } else {
+            for (size_t i = 0; i < size; ++i) {
+                c[i] = apply(a, b[i], null_map[i]);
+            }
         }
     }
 
-    static ResultType constant_constant(A a, B b) {
-        if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
-            return apply_scaled_div(a, b);
-        }
-        return apply(a, b);
-    }
+    static ResultType constant_constant(A a, B b) { return apply(a, b); }
 
     static ResultType constant_constant(A a, B b, UInt8& is_null) {
         if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
             return apply_scaled_div(a, b, is_null);
+        } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+                             (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+            NativeResultType res;
+            is_null = apply_op_safely(a, b, res);
+            return res;
+        } else {
+            return apply(a, b, is_null);
         }
-        return apply(a, b, is_null);
     }
 
     static ColumnPtr adapt_decimal_constant_constant(A a, B b, DataTypePtr res_data_type) {
         auto column_result = ColumnDecimal<ResultType>::create(
                 1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
 
-        if constexpr (is_to_null_type) {
+        if constexpr (return_nullable_type && !is_to_null_type &&
+                      ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+                       IsDecimalV2<B>)) {
+            LOG(FATAL) << "Invalid function type!";
+            return column_result;
+        } else if constexpr (return_nullable_type || is_to_null_type) {
             auto null_map = ColumnUInt8::create(1, 0);
             column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0));
             return ColumnNullable::create(std::move(column_result), std::move(null_map));
@@ -358,7 +372,12 @@ struct DecimalBinaryOperation {
                 assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
         DCHECK(column_left_ptr != nullptr);
 
-        if constexpr (is_to_null_type) {
+        if constexpr (return_nullable_type && !is_to_null_type &&
+                      ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+                       IsDecimalV2<B>)) {
+            LOG(FATAL) << "Invalid function type!";
+            return column_result;
+        } else if constexpr (return_nullable_type || is_to_null_type) {
             auto null_map = ColumnUInt8::create(column_left->size(), 0);
             vector_constant(column_left_ptr->get_data(), b, column_result->get_data(),
                             null_map->get_data());
@@ -377,7 +396,12 @@ struct DecimalBinaryOperation {
                 assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
         DCHECK(column_right_ptr != nullptr);
 
-        if constexpr (is_to_null_type) {
+        if constexpr (return_nullable_type && !is_to_null_type &&
+                      ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+                       IsDecimalV2<B>)) {
+            LOG(FATAL) << "Invalid function type!";
+            return column_result;
+        } else if constexpr (return_nullable_type || is_to_null_type) {
             auto null_map = ColumnUInt8::create(column_right->size(), 0);
             constant_vector(a, column_right_ptr->get_data(), column_result->get_data(),
                             null_map->get_data());
@@ -398,7 +422,12 @@ struct DecimalBinaryOperation {
                 assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
         DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);
 
-        if constexpr (is_to_null_type) {
+        if constexpr (return_nullable_type && !is_to_null_type &&
+                      ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+                       IsDecimalV2<B>)) {
+            LOG(FATAL) << "Invalid function type!";
+            return column_result;
+        } else if constexpr (return_nullable_type || is_to_null_type) {
             auto null_map = ColumnUInt8::create(column_result->size(), 0);
             vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(),
                           column_result->get_data(), null_map->get_data());
@@ -483,6 +512,12 @@ private:
                                              UInt8& is_null) {
         return apply(a, b, is_null);
     }
+
+    static UInt8 apply_op_safely(NativeResultType a, NativeResultType b, NativeResultType& c) {
+        if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
+            return Op::template apply(a, b, c);
+        }
+    }
 };
 
 /// Used to indicate undefined operation
@@ -568,7 +603,8 @@ struct BinaryOperationTraits {
 };
 
 template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType,
-          template <typename, typename> class Operation, bool is_to_null_type>
+          template <typename, typename> class Operation, bool is_to_null_type,
+          bool return_nullable_type>
 struct ConstOrVectorAdapter {
     static constexpr bool result_is_decimal =
             IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
@@ -580,7 +616,8 @@ struct ConstOrVectorAdapter {
 
     using OperationImpl = std::conditional_t<
             IsDataTypeDecimal<ResultDataType>,
-            DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type>,
+            DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type,
+                                   return_nullable_type>,
             BinaryOperationImpl<A, B, Operation<A, B>, is_to_null_type, ResultType>>;
 
     static ColumnPtr execute(ColumnPtr column_left, ColumnPtr column_right,
@@ -774,6 +811,7 @@ public:
             right_generic =
                     static_cast<const DataTypeNullable*>(right_generic)->get_nested_type().get();
         }
+        bool result_is_nullable = context->impl()->check_overflow_for_decimal();
         if (result_generic->is_nullable()) {
             result_generic =
                     static_cast<const DataTypeNullable*>(result_generic)->get_nested_type().get();
@@ -795,15 +833,31 @@ public:
                                      ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> ==
                                                         (IsDataTypeDecimal<LeftDataType> ||
                                                          IsDataTypeDecimal<RightDataType>))) {
-                        auto column_result = ConstOrVectorAdapter<
-                                LeftDataType, RightDataType,
-                                std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
-                                                   ExpectedResultDataType, ResultDataType>,
-                                Operation, is_to_null_type>::
-                                execute(block.get_by_position(arguments[0]).column,
-                                        block.get_by_position(arguments[1]).column, left, right,
-                                        remove_nullable(block.get_by_position(result).type));
-                        block.replace_by_position(result, std::move(column_result));
+                        if (result_is_nullable) {
+                            auto column_result = ConstOrVectorAdapter<
+                                    LeftDataType, RightDataType,
+                                    std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+                                                       ExpectedResultDataType, ResultDataType>,
+                                    Operation, is_to_null_type,
+                                    true>::execute(block.get_by_position(arguments[0]).column,
+                                                   block.get_by_position(arguments[1]).column, left,
+                                                   right,
+                                                   remove_nullable(
+                                                           block.get_by_position(result).type));
+                            block.replace_by_position(result, std::move(column_result));
+                        } else {
+                            auto column_result = ConstOrVectorAdapter<
+                                    LeftDataType, RightDataType,
+                                    std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+                                                       ExpectedResultDataType, ResultDataType>,
+                                    Operation, is_to_null_type,
+                                    false>::execute(block.get_by_position(arguments[0]).column,
+                                                    block.get_by_position(arguments[1]).column,
+                                                    left, right,
+                                                    remove_nullable(
+                                                            block.get_by_position(result).type));
+                            block.replace_by_position(result, std::move(column_result));
+                        }
                         return true;
                     }
                     return false;
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index 90edd3906b..e3baaecdd2 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -22,6 +22,7 @@
 
 #include <fmt/format.h>
 
+#include "udf/udf_internal.h"
 #include "vec/columns/column_array.h"
 #include "vec/columns/column_const.h"
 #include "vec/columns/column_nullable.h"
@@ -72,7 +73,7 @@ struct ConvertImpl {
 
     template <typename Additions = void*>
     static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
-                          size_t /*input_rows_count*/,
+                          size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
                           Additions additions [[maybe_unused]] = Additions()) {
         const ColumnWithTypeAndName& named_from = block.get_by_position(arguments[0]);
 
@@ -96,37 +97,50 @@ struct ConvertImpl {
             if constexpr (IsDataTypeDecimal<ToDataType>) {
                 UInt32 scale = additions;
                 col_to = ColVecTo::create(0, scale);
-            } else
+            } else {
                 col_to = ColVecTo::create();
+            }
 
             const auto& vec_from = col_from->get_data();
             auto& vec_to = col_to->get_data();
             size_t size = vec_from.size();
             vec_to.resize(size);
 
-            for (size_t i = 0; i < size; ++i) {
-                if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
-                    if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
+            if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
+                ColumnUInt8::MutablePtr col_null_map_to = nullptr;
+                UInt8* vec_null_map_to = nullptr;
+                if (check_overflow) {
+                    col_null_map_to = ColumnUInt8::create(size, 0);
+                    vec_null_map_to = col_null_map_to->get_data().data();
+                }
+                for (size_t i = 0; i < size; ++i) {
+                    if constexpr (IsDataTypeDecimal<FromDataType> &&
+                                  IsDataTypeDecimal<ToDataType>) {
                         vec_to[i] = convert_decimals<FromDataType, ToDataType>(
-                                vec_from[i], vec_from.get_scale(), vec_to.get_scale());
-                    else if constexpr (IsDataTypeDecimal<FromDataType> &&
-                                       IsDataTypeNumber<ToDataType>)
+                                vec_from[i], vec_from.get_scale(), vec_to.get_scale(),
+                                vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+                    } else if constexpr (IsDataTypeDecimal<FromDataType> &&
+                                         IsDataTypeNumber<ToDataType>) {
                         vec_to[i] = convert_from_decimal<FromDataType, ToDataType>(
                                 vec_from[i], vec_from.get_scale());
-                    else if constexpr (IsDataTypeNumber<FromDataType> &&
-                                       IsDataTypeDecimal<ToDataType>)
+                    } else if constexpr (IsDataTypeNumber<FromDataType> &&
+                                         IsDataTypeDecimal<ToDataType>) {
                         vec_to[i] = convert_to_decimal<FromDataType, ToDataType>(
-                                vec_from[i], vec_to.get_scale());
-                    else if constexpr (IsTimeType<FromDataType> && IsDataTypeDecimal<ToDataType>) {
+                                vec_from[i], vec_to.get_scale(),
+                                vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+                    } else if constexpr (IsTimeType<FromDataType> &&
+                                         IsDataTypeDecimal<ToDataType>) {
                         vec_to[i] = convert_to_decimal<DataTypeInt64, ToDataType>(
                                 reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64(),
-                                vec_to.get_scale());
+                                vec_to.get_scale(),
+                                vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
                     } else if constexpr (IsDateV2Type<FromDataType> &&
                                          IsDataTypeDecimal<ToDataType>) {
                         vec_to[i] = convert_to_decimal<DataTypeUInt32, ToDataType>(
                                 reinterpret_cast<const DateV2Value<DateV2ValueType>&>(vec_from[i])
                                         .to_date_int_val(),
-                                vec_to.get_scale());
+                                vec_to.get_scale(),
+                                vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
                     } else if constexpr (IsDateTimeV2Type<FromDataType> &&
                                          IsDataTypeDecimal<ToDataType>) {
                         // TODO: should we consider the scale of datetimev2?
@@ -134,9 +148,21 @@ struct ConvertImpl {
                                 reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>(
                                         vec_from[i])
                                         .to_date_int_val(),
-                                vec_to.get_scale());
+                                vec_to.get_scale(),
+                                vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
                     }
-                } else if constexpr (IsTimeType<FromDataType>) {
+                }
+                if (check_overflow) {
+                    block.replace_by_position(
+                            result,
+                            ColumnNullable::create(std::move(col_to), std::move(col_null_map_to)));
+                } else {
+                    block.replace_by_position(result, std::move(col_to));
+                }
+
+                return Status::OK();
+            } else if constexpr (IsTimeType<FromDataType>) {
+                for (size_t i = 0; i < size; ++i) {
                     if constexpr (IsTimeType<ToDataType>) {
                         vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
                         if constexpr (IsDateTimeType<ToDataType>) {
@@ -152,7 +178,9 @@ struct ConvertImpl {
                         vec_to[i] =
                                 reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64();
                     }
-                } else if constexpr (IsTimeV2Type<FromDataType>) {
+                }
+            } else if constexpr (IsTimeV2Type<FromDataType>) {
+                for (size_t i = 0; i < size; ++i) {
                     if constexpr (IsTimeV2Type<ToDataType>) {
                         if constexpr (IsDateTimeV2Type<ToDataType> && IsDateV2Type<FromDataType>) {
                             DataTypeDateV2::cast_to_date_time_v2(vec_from[i], vec_to[i]);
@@ -189,7 +217,9 @@ struct ConvertImpl {
                                                 .to_int64();
                         }
                     }
-                } else {
+                }
+            } else {
+                for (size_t i = 0; i < size; ++i) {
                     vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
                 }
             }
@@ -547,7 +577,7 @@ struct ConvertImpl<DataTypeString, ToDataType, Name> {
     template <typename Additions = void*>
 
     static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
-                          size_t /*input_rows_count*/,
+                          size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
                           Additions additions [[maybe_unused]] = Additions()) {
         return Status::RuntimeError("not support convert from string");
     }
@@ -832,19 +862,6 @@ public:
 
     Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) override {
-        return executeInternal(block, arguments, result, input_rows_count);
-    }
-
-    bool has_information_about_monotonicity() const override { return Monotonic::has(); }
-
-    Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
-                                            const Field& right) const override {
-        return Monotonic::get(type, left, right);
-    }
-
-private:
-    Status executeInternal(Block& block, const ColumnNumbers& arguments, size_t result,
-                           size_t input_rows_count) {
         if (!arguments.size()) {
             return Status::RuntimeError("Function {} expects at least 1 arguments", get_name());
         }
@@ -873,13 +890,15 @@ private:
                     UInt32 scale = extract_to_decimal_scale(scale_column);
 
                     ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
-                            block, arguments, result, input_rows_count, scale);
+                            block, arguments, result, input_rows_count,
+                            context->impl()->check_overflow_for_decimal(), scale);
                 } else if constexpr (IsDataTypeDateTimeV2<RightDataType>) {
                     const ColumnWithTypeAndName& scale_column = block.get_by_position(result);
                     auto type =
                             check_and_get_data_type<DataTypeDateTimeV2>(scale_column.type.get());
                     ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
-                            block, arguments, result, input_rows_count, type->get_scale());
+                            block, arguments, result, input_rows_count,
+                            context->impl()->check_overflow_for_decimal(), type->get_scale());
                 } else {
                     ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
                             block, arguments, result, input_rows_count);
@@ -896,6 +915,13 @@ private:
             return ret_status;
         }
     }
+
+    bool has_information_about_monotonicity() const override { return Monotonic::has(); }
+
+    Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
+                                            const Field& right) const override {
+        return Monotonic::get(type, left, right);
+    }
 };
 
 using FunctionToUInt8 = FunctionConvert<DataTypeUInt8, NameToUInt8, ToNumberMonotonicity<UInt8>>;
@@ -1055,7 +1081,7 @@ struct ConvertThroughParsing {
 
     template <typename Additions = void*>
     static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
-                          size_t input_rows_count,
+                          size_t input_rows_count, bool check_overflow [[maybe_unused]] = false,
                           Additions additions [[maybe_unused]] = Additions()) {
         using ColVecTo = std::conditional_t<IsDecimalNumber<ToFieldType>,
                                             ColumnDecimal<ToFieldType>, ColumnVector<ToFieldType>>;
@@ -1254,7 +1280,8 @@ public:
                                 const ColumnNumbers& /*arguments*/,
                                 size_t /*result*/) const override {
         return std::make_shared<PreparedFunctionCast>(
-                prepare_unpack_dictionaries(get_argument_types()[0], get_return_type()), name);
+                prepare_unpack_dictionaries(context, get_argument_types()[0], get_return_type()),
+                name);
     }
 
     String get_name() const override { return name; }
@@ -1347,7 +1374,8 @@ private:
                         using RightDataType = typename Types::RightType;
 
                         ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
-                                block, arguments, result, input_rows_count, scale);
+                                block, arguments, result, input_rows_count,
+                                context->impl()->check_overflow_for_decimal(), scale);
                         return true;
                     });
 
@@ -1396,7 +1424,7 @@ private:
         return create_unsupport_wrapper(error_msg);
     }
 
-    WrapperType create_array_wrapper(const DataTypePtr& from_type_untyped,
+    WrapperType create_array_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped,
                                      const DataTypeArray& to_type) const {
         /// Conversion from String through parsing.
         if (check_and_get_data_type<DataTypeString>(from_type_untyped.get())) {
@@ -1425,7 +1453,8 @@ private:
         const DataTypePtr& to_nested_type = to_type.get_nested_type();
 
         /// Prepare nested type conversion
-        const auto nested_function = prepare_unpack_dictionaries(from_nested_type, to_nested_type);
+        const auto nested_function =
+                prepare_unpack_dictionaries(context, from_nested_type, to_nested_type);
 
         return [nested_function, from_nested_type, to_nested_type](
                        FunctionContext* context, Block& block, const ColumnNumbers& arguments,
@@ -1513,7 +1542,7 @@ private:
         }
     }
 
-    WrapperType prepare_unpack_dictionaries(const DataTypePtr& from_type,
+    WrapperType prepare_unpack_dictionaries(FunctionContext* context, const DataTypePtr& from_type,
                                             const DataTypePtr& to_type) const {
         const auto& from_nested = from_type;
         const auto& to_nested = to_type;
@@ -1534,18 +1563,20 @@ private:
 
         constexpr bool skip_not_null_check = false;
 
-        auto wrapper = prepare_remove_nullable(from_nested, to_nested, skip_not_null_check);
+        auto wrapper =
+                prepare_remove_nullable(context, from_nested, to_nested, skip_not_null_check);
 
         return wrapper;
     }
 
-    WrapperType prepare_remove_nullable(const DataTypePtr& from_type, const DataTypePtr& to_type,
+    WrapperType prepare_remove_nullable(FunctionContext* context, const DataTypePtr& from_type,
+                                        const DataTypePtr& to_type,
                                         bool skip_not_null_check) const {
         /// Determine whether pre-processing and/or post-processing must take place during conversion.
         bool source_is_nullable = from_type->is_nullable();
         bool result_is_nullable = to_type->is_nullable();
 
-        auto wrapper = prepare_impl(remove_nullable(from_type), remove_nullable(to_type),
+        auto wrapper = prepare_impl(context, remove_nullable(from_type), remove_nullable(to_type),
                                     result_is_nullable);
 
         if (result_is_nullable) {
@@ -1620,8 +1651,8 @@ private:
 
     /// 'from_type' and 'to_type' are nested types in case of Nullable.
     /// 'requested_result_is_nullable' is true if CAST to Nullable type is requested.
-    WrapperType prepare_impl(const DataTypePtr& from_type, const DataTypePtr& to_type,
-                             bool requested_result_is_nullable) const {
+    WrapperType prepare_impl(FunctionContext* context, const DataTypePtr& from_type,
+                             const DataTypePtr& to_type, bool requested_result_is_nullable) const {
         if (from_type->equals(*to_type))
             return create_identity_wrapper(from_type);
         else if (WhichDataType(from_type).is_nothing())
@@ -1679,7 +1710,8 @@ private:
         case TypeIndex::String:
             return create_string_wrapper(from_type);
         case TypeIndex::Array:
-            return create_array_wrapper(from_type, static_cast<const DataTypeArray&>(*to_type));
+            return create_array_wrapper(context, from_type,
+                                        static_cast<const DataTypeArray&>(*to_type));
         default:
             break;
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
index e5502fe635..99b6e39e3b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
@@ -21,6 +21,7 @@
 package org.apache.doris.analysis;
 
 import org.apache.doris.catalog.Function;
+import org.apache.doris.catalog.Function.NullableMode;
 import org.apache.doris.catalog.FunctionSet;
 import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.ScalarFunction;
@@ -107,12 +108,13 @@ public class ArithmeticExpr extends Expr {
 
     public static void initBuiltins(FunctionSet functionSet) {
         for (Type t : Type.getNumericTypes()) {
+            NullableMode mode = t.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
             functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
-                    Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t));
+                    Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t, mode));
             functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
-                    Operator.ADD.getName(), Lists.newArrayList(t, t), t));
+                    Operator.ADD.getName(), Lists.newArrayList(t, t), t, mode));
             functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
-                    Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t));
+                    Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t, mode));
         }
         functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
                 Operator.DIVIDE.getName(),
@@ -173,15 +175,14 @@ public class ArithmeticExpr extends Expr {
             for (int j = 0; j < Type.getNumericTypes().size(); j++) {
                 Type t2 = Type.getNumericTypes().get(j);
 
+                Type retType = Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false));
+                NullableMode mode = retType.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
                 functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
-                        Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2),
-                        Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+                        Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2), retType, mode));
                 functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
-                        Operator.ADD.getName(), Lists.newArrayList(t1, t2),
-                        Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+                        Operator.ADD.getName(), Lists.newArrayList(t1, t2), retType, mode));
                 functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
-                        Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2),
-                        Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+                        Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2), retType, mode));
             }
         }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index 79df8f3395..9c29c7ffee 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -20,6 +20,7 @@
 
 package org.apache.doris.analysis;
 
+import org.apache.doris.analysis.ArithmeticExpr.Operator;
 import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.Function;
 import org.apache.doris.catalog.FunctionSet;
@@ -31,6 +32,7 @@ import org.apache.doris.common.Config;
 import org.apache.doris.common.TreeNode;
 import org.apache.doris.common.io.Writable;
 import org.apache.doris.common.util.VectorizedUtil;
+import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.statistics.ExprStats;
 import org.apache.doris.thrift.TExpr;
 import org.apache.doris.thrift.TExprNode;
@@ -2036,6 +2038,25 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
         if (fn.functionName().equalsIgnoreCase("concat_ws")) {
             return children.get(0).isNullable();
         }
+        if (fn.functionName().equalsIgnoreCase(Operator.MULTIPLY.getName())
+                && fn.getReturnType().isDecimalV3()) {
+            if (ConnectContext.get() != null
+                    && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+                return true;
+            } else {
+                return hasNullableChild();
+            }
+        }
+        if ((fn.functionName().equalsIgnoreCase(Operator.ADD.getName())
+                || fn.functionName().equalsIgnoreCase(Operator.SUBTRACT.getName()))
+                && fn.getReturnType().isDecimalV3()) {
+            if (ConnectContext.get() != null
+                    && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+                return true;
+            } else {
+                return hasNullableChild();
+            }
+        }
         return true;
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 7faa492295..dcea629834 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -189,6 +189,8 @@ public class SessionVariable implements Serializable, Writable {
 
     public static final String ENABLE_PROJECTION = "enable_projection";
 
+    public static final String CHECK_OVERFLOW_FOR_DECIMAL = "check_overflow_for_decimal";
+
     public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY
             = "trim_tailing_spaces_for_external_table_query";
 
@@ -542,6 +544,9 @@ public class SessionVariable implements Serializable, Writable {
     @VariableMgr.VarAttr(name = ENABLE_PROJECTION)
     private boolean enableProjection = true;
 
+    @VariableMgr.VarAttr(name = CHECK_OVERFLOW_FOR_DECIMAL)
+    private boolean checkOverflowForDecimal = false;
+
     /**
      * as the new optimizer is not mature yet, use this var
      * to control whether to use new optimizer, remove it when
@@ -1235,6 +1240,10 @@ public class SessionVariable implements Serializable, Writable {
         return enableProjection;
     }
 
+    public boolean checkOverflowForDecimal() {
+        return checkOverflowForDecimal;
+    }
+
     public boolean isTrimTailingSpacesForExternalTableQuery() {
         return trimTailingSpacesForExternalTableQuery;
     }
@@ -1368,6 +1377,7 @@ public class SessionVariable implements Serializable, Writable {
         }
 
         tResult.setEnableFunctionPushdown(enableFunctionPushdown);
+        tResult.setCheckOverflowForDecimal(checkOverflowForDecimal);
         tResult.setFragmentTransmissionCompressionCodec(fragmentTransmissionCompressionCodec);
         tResult.setEnableLocalExchange(enableLocalExchange);
         tResult.setEnableNewShuffleHashMethod(enableNewShuffleHashMethod);
diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift
index 32d6721bb3..bb241c4aa0 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -187,6 +187,7 @@ struct TQueryOptions {
   55: optional bool enable_pipeline_engine = false
 
   56: optional i32 repeat_max_num = 0
+  57: optional bool check_overflow_for_decimal = false
 }
     
 
diff --git a/regression-test/data/datatype_p0/decimalv3/test_overflow.out b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
new file mode 100644
index 0000000000..c9b9873cd7
--- /dev/null
+++ b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
@@ -0,0 +1,19 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !select_all --
+11111111111111111111.100000000000000000	11111111111111111111.200000000000000000	11111111111111111111.300000000000000000	1.1000000000000000000000000000000000000	1.2000000000000000000000000000000000000	1.3000000000000000000000000000000000000	9
+
+-- !select_check_overflow1 --
+\N	\N	\N	99999999999999999999.900000000000000000	\N
+
+-- !select_check_overflow2 --
+1.1000000000000000000000000000000000000	111111111111111111111.000000000000000000	\N
+
+-- !select_check_overflow3 --
+11111111111111111111.100000000000000000	\N
+
+-- !select_not_check_overflow1 --
+99.999999999999999999999999999999999999	99.999999999999999999999999999999999999	1.1111111111111111E21	99999999999999999999.900000000000000000	99999999999999999999.999999999999999999
+
+-- !select_not_check_overflow2 --
+1.1000000000000000000000000000000000000	111111111111111111111.000000000000000000	-15.9141183460469231731687303715884105728
+
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
new file mode 100644
index 0000000000..01de2ea498
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_overflow") {
+    // Exercises the check_overflow_for_decimal session variable on DECIMALV3 multiplication, addition and CAST.
+    def table1 = "test_overflow"
+    // NOTE(review): the CREATE/INSERT/SELECT statements below hard-code `test_overflow` rather than using ${table1}.
+    sql "drop table if exists ${table1}"
+    // k1-k3 are decimalv3(38, 18), v1-v3 are decimalv3(38, 37), v4 is INT.
+    sql """
+    CREATE TABLE IF NOT EXISTS test_overflow (
+      `k1` decimalv3(38, 18) NULL COMMENT "",
+      `k2` decimalv3(38, 18) NULL COMMENT "",
+      `k3` decimalv3(38, 18) NULL COMMENT "",
+      `v1` decimalv3(38, 37) NULL COMMENT "",
+      `v2` decimalv3(38, 37) NULL COMMENT "",
+      `v3` decimalv3(38, 37) NULL COMMENT "",
+      `v4` INT NULL COMMENT ""
+    ) ENGINE=OLAP
+    COMMENT "OLAP"
+    DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8
+    PROPERTIES (
+    "replication_allocation" = "tag.location.default: 1",
+    "in_memory" = "false",
+    "storage_format" = "V2"
+    )
+    """
+    // Expected rows live in test_overflow.out; each qt_<tag> below is compared against its "-- !<tag> --" section.
+    sql """insert into test_overflow values(11111111111111111111.1,11111111111111111111.2,11111111111111111111.3, 1.1,1.2,1.3,9)
+    """
+    qt_select_all "select * from test_overflow order by k1"
+    // Overflow checking on: results that overflow (e.g. k1 * k2, the CAST to DECIMALV3(38, 36)) are expected to be NULL.
+    sql " SET check_overflow_for_decimal = true; "
+    qt_select_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;"
+    qt_select_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow"
+    qt_select_check_overflow3 "select `k1`, cast (`k1` as DECIMALV3(38, 36)) from test_overflow;"
+    // Overflow checking off (the session default): the same expressions are expected to return non-NULL values.
+    sql " SET check_overflow_for_decimal = false; "
+    qt_select_not_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;"
+    qt_select_not_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow"
+    sql "drop table if exists ${table1}"
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org