You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by ga...@apache.org on 2022/12/30 06:02:30 UTC
[doris] branch master updated: [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)
This is an automated email from the ASF dual-hosted git repository.
gabriellee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 520b6d7910 [Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)
520b6d7910 is described below
commit 520b6d791044425c211fea54c088d5f10893b060
Author: Gabriel <ga...@gmail.com>
AuthorDate: Fri Dec 30 14:02:24 2022 +0800
[Improvement](decimalv3) Add a config to check overflow for DECIMALV3 (#15463)
---
be/src/runtime/runtime_state.h | 5 +
be/src/udf/udf_internal.h | 8 ++
be/src/vec/data_types/data_type_decimal.h | 30 ++++-
be/src/vec/exprs/vexpr_context.cpp | 2 +
be/src/vec/functions/function_binary_arithmetic.h | 142 ++++++++++++++-------
be/src/vec/functions/function_cast.h | 124 +++++++++++-------
.../org/apache/doris/analysis/ArithmeticExpr.java | 19 +--
.../main/java/org/apache/doris/analysis/Expr.java | 21 +++
.../java/org/apache/doris/qe/SessionVariable.java | 10 ++
gensrc/thrift/PaloInternalService.thrift | 1 +
.../data/datatype_p0/decimalv3/test_overflow.out | 19 +++
.../datatype_p0/decimalv3/test_overflow.groovy | 56 ++++++++
12 files changed, 335 insertions(+), 102 deletions(-)
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index dedef5340d..e9650d0702 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -135,6 +135,11 @@ public:
_query_options.enable_function_pushdown;
}
+ bool check_overflow_for_decimal() const {
+ return _query_options.__isset.check_overflow_for_decimal &&
+ _query_options.check_overflow_for_decimal;
+ }
+
// Create a codegen object in _codegen. No-op if it has already been called.
// If codegen is enabled for the query, this is created when the runtime
// state is created. If codegen is disabled for the query, this is created
diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h
index 1bc4fefd0b..67a8ec60e7 100644
--- a/be/src/udf/udf_internal.h
+++ b/be/src/udf/udf_internal.h
@@ -109,6 +109,12 @@ public:
const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; }
+ const bool check_overflow_for_decimal() const { return _check_overflow_for_decimal; }
+
+ bool set_check_overflow_for_decimal(bool check_overflow_for_decimal) {
+ return _check_overflow_for_decimal = check_overflow_for_decimal;
+ }
+
private:
friend class doris_udf::FunctionContext;
friend class ExprContext;
@@ -181,6 +187,8 @@ private:
// call that passes the correct AnyVal subclass pointer type.
std::vector<doris_udf::AnyVal*> _staging_input_vals;
+ bool _check_overflow_for_decimal = false;
+
// Indicates whether this context has been closed. Used for verification/debugging.
bool _closed;
diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h
index c8e08303a1..2213093104 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -71,6 +71,11 @@ constexpr Int128 max_decimal_value<Decimal128>() {
return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
}
+template <>
+constexpr Int128 max_decimal_value<Decimal128I>() {
+ return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
+ static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
+}
DataTypePtr create_decimal(UInt64 precision, UInt64 scale, bool use_v2);
@@ -291,8 +296,8 @@ constexpr bool IsDataTypeDecimalOrNumber =
template <typename FromDataType, typename ToDataType>
inline std::enable_if_t<IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>,
typename ToDataType::FieldType>
-convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from,
- UInt32 scale_to) {
+convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from, UInt32 scale_to,
+ UInt8* overflow_flag = nullptr) {
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
using MaxFieldType =
@@ -310,6 +315,9 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_to - scale_from);
if (common::mul_overflow(static_cast<MaxNativeType>(value), converted_value,
converted_value)) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return converted_value < 0
? std::numeric_limits<typename ToFieldType::NativeType>::min()
@@ -322,10 +330,16 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType)) {
if (converted_value < std::numeric_limits<typename ToFieldType::NativeType>::min()) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return std::numeric_limits<typename ToFieldType::NativeType>::min();
}
if (converted_value > std::numeric_limits<typename ToFieldType::NativeType>::max()) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return std::numeric_limits<typename ToFieldType::NativeType>::max();
}
@@ -381,12 +395,16 @@ convert_from_decimal(const typename FromDataType::FieldType& value, UInt32 scale
template <typename FromDataType, typename ToDataType>
inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>,
typename ToDataType::FieldType>
-convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) {
+convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale,
+ UInt8* overflow_flag) {
using FromFieldType = typename FromDataType::FieldType;
using ToNativeType = typename ToDataType::FieldType::NativeType;
if constexpr (std::is_floating_point_v<FromFieldType>) {
if (!std::isfinite(value)) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Cannot convert infinity or NaN to decimal";
return value < 0 ? std::numeric_limits<ToNativeType>::min()
: std::numeric_limits<ToNativeType>::max();
@@ -395,10 +413,16 @@ convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale)
FromFieldType out;
out = value * ToDataType::get_scale_multiplier(scale);
if (out <= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::min())) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
return std::numeric_limits<ToNativeType>::min();
}
if (out >= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::max())) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
return std::numeric_limits<ToNativeType>::max();
}
diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp
index ccb1045cb1..9033245202 100644
--- a/be/src/vec/exprs/vexpr_context.cpp
+++ b/be/src/vec/exprs/vexpr_context.cpp
@@ -110,6 +110,8 @@ int VExprContext::register_func(RuntimeState* state, const FunctionContext::Type
int varargs_buffer_size) {
_fn_contexts.push_back(FunctionContextImpl::create_context(
state, _pool.get(), return_type, arg_types, varargs_buffer_size, false));
+ _fn_contexts.back()->impl()->set_check_overflow_for_decimal(
+ state->check_overflow_for_decimal());
return _fn_contexts.size() - 1;
}
diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h
index 2a8da748e3..d1c0375a55 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -23,6 +23,7 @@
#include <type_traits>
#include "runtime/decimalv2_value.h"
+#include "udf/udf_internal.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_nullable.h"
@@ -216,7 +217,8 @@ struct BinaryOperationImpl {
/// * no args scale. ScaleR = Scale1 + Scale2;
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
template <typename A, typename B, template <typename, typename> typename Operation,
- typename ResultType, bool is_to_null_type, bool check_overflow = true>
+ typename ResultType, bool is_to_null_type, bool return_nullable_type,
+ bool check_overflow = true>
struct DecimalBinaryOperation {
using OpTraits = OperationTraits<Operation>;
@@ -249,12 +251,14 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply(a[i], b[i], null_map[i]);
}
- } else {
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
- }
- return;
+ } else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
+ }
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a[i], b[i], c[i].value);
}
}
}
@@ -281,21 +285,21 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a[i], b, null_map[i]);
}
- return;
- }
-
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply(a[i], b, null_map[i]);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a[i], b, c[i].value);
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply(a[i], b, null_map[i]);
+ }
}
}
static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c) {
size_t size = b.size();
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply_scaled_div(a, b[i]);
- }
- } else if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
+ if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
DecimalV2Value da(a);
for (size_t i = 0; i < size; ++i) {
c[i] = Op::template apply(da, DecimalV2Value(b[i])).value();
@@ -314,33 +318,43 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a, b[i], null_map[i]);
}
- return;
- }
-
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply(a, b[i], null_map[i]);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a, b[i], c[i].value);
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply(a, b[i], null_map[i]);
+ }
}
}
- static ResultType constant_constant(A a, B b) {
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- return apply_scaled_div(a, b);
- }
- return apply(a, b);
- }
+ static ResultType constant_constant(A a, B b) { return apply(a, b); }
static ResultType constant_constant(A a, B b, UInt8& is_null) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
return apply_scaled_div(a, b, is_null);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ NativeResultType res;
+ is_null = apply_op_safely(a, b, res);
+ return res;
+ } else {
+ return apply(a, b, is_null);
}
- return apply(a, b, is_null);
}
static ColumnPtr adapt_decimal_constant_constant(A a, B b, DataTypePtr res_data_type) {
auto column_result = ColumnDecimal<ResultType>::create(
1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(1, 0);
column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0));
return ColumnNullable::create(std::move(column_result), std::move(null_map));
@@ -358,7 +372,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_left_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_left->size(), 0);
vector_constant(column_left_ptr->get_data(), b, column_result->get_data(),
null_map->get_data());
@@ -377,7 +396,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_right_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_right->size(), 0);
constant_vector(a, column_right_ptr->get_data(), column_result->get_data(),
null_map->get_data());
@@ -398,7 +422,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_result->size(), 0);
vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(),
column_result->get_data(), null_map->get_data());
@@ -483,6 +512,12 @@ private:
UInt8& is_null) {
return apply(a, b, is_null);
}
+
+ static UInt8 apply_op_safely(NativeResultType a, NativeResultType b, NativeResultType& c) {
+ if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
+ return Op::template apply(a, b, c);
+ }
+ }
};
/// Used to indicate undefined operation
@@ -568,7 +603,8 @@ struct BinaryOperationTraits {
};
template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType,
- template <typename, typename> class Operation, bool is_to_null_type>
+ template <typename, typename> class Operation, bool is_to_null_type,
+ bool return_nullable_type>
struct ConstOrVectorAdapter {
static constexpr bool result_is_decimal =
IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
@@ -580,7 +616,8 @@ struct ConstOrVectorAdapter {
using OperationImpl = std::conditional_t<
IsDataTypeDecimal<ResultDataType>,
- DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type>,
+ DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type,
+ return_nullable_type>,
BinaryOperationImpl<A, B, Operation<A, B>, is_to_null_type, ResultType>>;
static ColumnPtr execute(ColumnPtr column_left, ColumnPtr column_right,
@@ -774,6 +811,7 @@ public:
right_generic =
static_cast<const DataTypeNullable*>(right_generic)->get_nested_type().get();
}
+ bool result_is_nullable = context->impl()->check_overflow_for_decimal();
if (result_generic->is_nullable()) {
result_generic =
static_cast<const DataTypeNullable*>(result_generic)->get_nested_type().get();
@@ -795,15 +833,31 @@ public:
ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> ==
(IsDataTypeDecimal<LeftDataType> ||
IsDataTypeDecimal<RightDataType>))) {
- auto column_result = ConstOrVectorAdapter<
- LeftDataType, RightDataType,
- std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
- ExpectedResultDataType, ResultDataType>,
- Operation, is_to_null_type>::
- execute(block.get_by_position(arguments[0]).column,
- block.get_by_position(arguments[1]).column, left, right,
- remove_nullable(block.get_by_position(result).type));
- block.replace_by_position(result, std::move(column_result));
+ if (result_is_nullable) {
+ auto column_result = ConstOrVectorAdapter<
+ LeftDataType, RightDataType,
+ std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+ ExpectedResultDataType, ResultDataType>,
+ Operation, is_to_null_type,
+ true>::execute(block.get_by_position(arguments[0]).column,
+ block.get_by_position(arguments[1]).column, left,
+ right,
+ remove_nullable(
+ block.get_by_position(result).type));
+ block.replace_by_position(result, std::move(column_result));
+ } else {
+ auto column_result = ConstOrVectorAdapter<
+ LeftDataType, RightDataType,
+ std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+ ExpectedResultDataType, ResultDataType>,
+ Operation, is_to_null_type,
+ false>::execute(block.get_by_position(arguments[0]).column,
+ block.get_by_position(arguments[1]).column,
+ left, right,
+ remove_nullable(
+ block.get_by_position(result).type));
+ block.replace_by_position(result, std::move(column_result));
+ }
return true;
}
return false;
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index 90edd3906b..e3baaecdd2 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -22,6 +22,7 @@
#include <fmt/format.h>
+#include "udf/udf_internal.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
@@ -72,7 +73,7 @@ struct ConvertImpl {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t /*input_rows_count*/,
+ size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
const ColumnWithTypeAndName& named_from = block.get_by_position(arguments[0]);
@@ -96,37 +97,50 @@ struct ConvertImpl {
if constexpr (IsDataTypeDecimal<ToDataType>) {
UInt32 scale = additions;
col_to = ColVecTo::create(0, scale);
- } else
+ } else {
col_to = ColVecTo::create();
+ }
const auto& vec_from = col_from->get_data();
auto& vec_to = col_to->get_data();
size_t size = vec_from.size();
vec_to.resize(size);
- for (size_t i = 0; i < size; ++i) {
- if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
- if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
+ if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
+ ColumnUInt8::MutablePtr col_null_map_to = nullptr;
+ UInt8* vec_null_map_to = nullptr;
+ if (check_overflow) {
+ col_null_map_to = ColumnUInt8::create(size, 0);
+ vec_null_map_to = col_null_map_to->get_data().data();
+ }
+ for (size_t i = 0; i < size; ++i) {
+ if constexpr (IsDataTypeDecimal<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_decimals<FromDataType, ToDataType>(
- vec_from[i], vec_from.get_scale(), vec_to.get_scale());
- else if constexpr (IsDataTypeDecimal<FromDataType> &&
- IsDataTypeNumber<ToDataType>)
+ vec_from[i], vec_from.get_scale(), vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+ } else if constexpr (IsDataTypeDecimal<FromDataType> &&
+ IsDataTypeNumber<ToDataType>) {
vec_to[i] = convert_from_decimal<FromDataType, ToDataType>(
vec_from[i], vec_from.get_scale());
- else if constexpr (IsDataTypeNumber<FromDataType> &&
- IsDataTypeDecimal<ToDataType>)
+ } else if constexpr (IsDataTypeNumber<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<FromDataType, ToDataType>(
- vec_from[i], vec_to.get_scale());
- else if constexpr (IsTimeType<FromDataType> && IsDataTypeDecimal<ToDataType>) {
+ vec_from[i], vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+ } else if constexpr (IsTimeType<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<DataTypeInt64, ToDataType>(
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
} else if constexpr (IsDateV2Type<FromDataType> &&
IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<DataTypeUInt32, ToDataType>(
reinterpret_cast<const DateV2Value<DateV2ValueType>&>(vec_from[i])
.to_date_int_val(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
} else if constexpr (IsDateTimeV2Type<FromDataType> &&
IsDataTypeDecimal<ToDataType>) {
// TODO: should we consider the scale of datetimev2?
@@ -134,9 +148,21 @@ struct ConvertImpl {
reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>(
vec_from[i])
.to_date_int_val(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
}
- } else if constexpr (IsTimeType<FromDataType>) {
+ }
+ if (check_overflow) {
+ block.replace_by_position(
+ result,
+ ColumnNullable::create(std::move(col_to), std::move(col_null_map_to)));
+ } else {
+ block.replace_by_position(result, std::move(col_to));
+ }
+
+ return Status::OK();
+ } else if constexpr (IsTimeType<FromDataType>) {
+ for (size_t i = 0; i < size; ++i) {
if constexpr (IsTimeType<ToDataType>) {
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
if constexpr (IsDateTimeType<ToDataType>) {
@@ -152,7 +178,9 @@ struct ConvertImpl {
vec_to[i] =
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64();
}
- } else if constexpr (IsTimeV2Type<FromDataType>) {
+ }
+ } else if constexpr (IsTimeV2Type<FromDataType>) {
+ for (size_t i = 0; i < size; ++i) {
if constexpr (IsTimeV2Type<ToDataType>) {
if constexpr (IsDateTimeV2Type<ToDataType> && IsDateV2Type<FromDataType>) {
DataTypeDateV2::cast_to_date_time_v2(vec_from[i], vec_to[i]);
@@ -189,7 +217,9 @@ struct ConvertImpl {
.to_int64();
}
}
- } else {
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
}
}
@@ -547,7 +577,7 @@ struct ConvertImpl<DataTypeString, ToDataType, Name> {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t /*input_rows_count*/,
+ size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
return Status::RuntimeError("not support convert from string");
}
@@ -832,19 +862,6 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) override {
- return executeInternal(block, arguments, result, input_rows_count);
- }
-
- bool has_information_about_monotonicity() const override { return Monotonic::has(); }
-
- Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
- const Field& right) const override {
- return Monotonic::get(type, left, right);
- }
-
-private:
- Status executeInternal(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t input_rows_count) {
if (!arguments.size()) {
return Status::RuntimeError("Function {} expects at least 1 arguments", get_name());
}
@@ -873,13 +890,15 @@ private:
UInt32 scale = extract_to_decimal_scale(scale_column);
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
- block, arguments, result, input_rows_count, scale);
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), scale);
} else if constexpr (IsDataTypeDateTimeV2<RightDataType>) {
const ColumnWithTypeAndName& scale_column = block.get_by_position(result);
auto type =
check_and_get_data_type<DataTypeDateTimeV2>(scale_column.type.get());
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
- block, arguments, result, input_rows_count, type->get_scale());
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), type->get_scale());
} else {
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
block, arguments, result, input_rows_count);
@@ -896,6 +915,13 @@ private:
return ret_status;
}
}
+
+ bool has_information_about_monotonicity() const override { return Monotonic::has(); }
+
+ Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
+ const Field& right) const override {
+ return Monotonic::get(type, left, right);
+ }
};
using FunctionToUInt8 = FunctionConvert<DataTypeUInt8, NameToUInt8, ToNumberMonotonicity<UInt8>>;
@@ -1055,7 +1081,7 @@ struct ConvertThroughParsing {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t input_rows_count,
+ size_t input_rows_count, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
using ColVecTo = std::conditional_t<IsDecimalNumber<ToFieldType>,
ColumnDecimal<ToFieldType>, ColumnVector<ToFieldType>>;
@@ -1254,7 +1280,8 @@ public:
const ColumnNumbers& /*arguments*/,
size_t /*result*/) const override {
return std::make_shared<PreparedFunctionCast>(
- prepare_unpack_dictionaries(get_argument_types()[0], get_return_type()), name);
+ prepare_unpack_dictionaries(context, get_argument_types()[0], get_return_type()),
+ name);
}
String get_name() const override { return name; }
@@ -1347,7 +1374,8 @@ private:
using RightDataType = typename Types::RightType;
ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
- block, arguments, result, input_rows_count, scale);
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), scale);
return true;
});
@@ -1396,7 +1424,7 @@ private:
return create_unsupport_wrapper(error_msg);
}
- WrapperType create_array_wrapper(const DataTypePtr& from_type_untyped,
+ WrapperType create_array_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped,
const DataTypeArray& to_type) const {
/// Conversion from String through parsing.
if (check_and_get_data_type<DataTypeString>(from_type_untyped.get())) {
@@ -1425,7 +1453,8 @@ private:
const DataTypePtr& to_nested_type = to_type.get_nested_type();
/// Prepare nested type conversion
- const auto nested_function = prepare_unpack_dictionaries(from_nested_type, to_nested_type);
+ const auto nested_function =
+ prepare_unpack_dictionaries(context, from_nested_type, to_nested_type);
return [nested_function, from_nested_type, to_nested_type](
FunctionContext* context, Block& block, const ColumnNumbers& arguments,
@@ -1513,7 +1542,7 @@ private:
}
}
- WrapperType prepare_unpack_dictionaries(const DataTypePtr& from_type,
+ WrapperType prepare_unpack_dictionaries(FunctionContext* context, const DataTypePtr& from_type,
const DataTypePtr& to_type) const {
const auto& from_nested = from_type;
const auto& to_nested = to_type;
@@ -1534,18 +1563,20 @@ private:
constexpr bool skip_not_null_check = false;
- auto wrapper = prepare_remove_nullable(from_nested, to_nested, skip_not_null_check);
+ auto wrapper =
+ prepare_remove_nullable(context, from_nested, to_nested, skip_not_null_check);
return wrapper;
}
- WrapperType prepare_remove_nullable(const DataTypePtr& from_type, const DataTypePtr& to_type,
+ WrapperType prepare_remove_nullable(FunctionContext* context, const DataTypePtr& from_type,
+ const DataTypePtr& to_type,
bool skip_not_null_check) const {
/// Determine whether pre-processing and/or post-processing must take place during conversion.
bool source_is_nullable = from_type->is_nullable();
bool result_is_nullable = to_type->is_nullable();
- auto wrapper = prepare_impl(remove_nullable(from_type), remove_nullable(to_type),
+ auto wrapper = prepare_impl(context, remove_nullable(from_type), remove_nullable(to_type),
result_is_nullable);
if (result_is_nullable) {
@@ -1620,8 +1651,8 @@ private:
/// 'from_type' and 'to_type' are nested types in case of Nullable.
/// 'requested_result_is_nullable' is true if CAST to Nullable type is requested.
- WrapperType prepare_impl(const DataTypePtr& from_type, const DataTypePtr& to_type,
- bool requested_result_is_nullable) const {
+ WrapperType prepare_impl(FunctionContext* context, const DataTypePtr& from_type,
+ const DataTypePtr& to_type, bool requested_result_is_nullable) const {
if (from_type->equals(*to_type))
return create_identity_wrapper(from_type);
else if (WhichDataType(from_type).is_nothing())
@@ -1679,7 +1710,8 @@ private:
case TypeIndex::String:
return create_string_wrapper(from_type);
case TypeIndex::Array:
- return create_array_wrapper(from_type, static_cast<const DataTypeArray&>(*to_type));
+ return create_array_wrapper(context, from_type,
+ static_cast<const DataTypeArray&>(*to_type));
default:
break;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
index e5502fe635..99b6e39e3b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
@@ -21,6 +21,7 @@
package org.apache.doris.analysis;
import org.apache.doris.catalog.Function;
+import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
@@ -107,12 +108,13 @@ public class ArithmeticExpr extends Expr {
public static void initBuiltins(FunctionSet functionSet) {
for (Type t : Type.getNumericTypes()) {
+ NullableMode mode = t.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t));
+ Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t, mode));
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.ADD.getName(), Lists.newArrayList(t, t), t));
+ Operator.ADD.getName(), Lists.newArrayList(t, t), t, mode));
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t));
+ Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t, mode));
}
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
Operator.DIVIDE.getName(),
@@ -173,15 +175,14 @@ public class ArithmeticExpr extends Expr {
for (int j = 0; j < Type.getNumericTypes().size(); j++) {
Type t2 = Type.getNumericTypes().get(j);
+ Type retType = Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false));
+ NullableMode mode = retType.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2), retType, mode));
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.ADD.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.ADD.getName(), Lists.newArrayList(t1, t2), retType, mode));
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2), retType, mode));
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index 79df8f3395..9c29c7ffee 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -20,6 +20,7 @@
package org.apache.doris.analysis;
+import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.FunctionSet;
@@ -31,6 +32,7 @@ import org.apache.doris.common.Config;
import org.apache.doris.common.TreeNode;
import org.apache.doris.common.io.Writable;
import org.apache.doris.common.util.VectorizedUtil;
+import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ExprStats;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TExprNode;
@@ -2036,6 +2038,25 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
if (fn.functionName().equalsIgnoreCase("concat_ws")) {
return children.get(0).isNullable();
}
+ if (fn.functionName().equalsIgnoreCase(Operator.MULTIPLY.getName())
+ && fn.getReturnType().isDecimalV3()) {
+ if (ConnectContext.get() != null
+ && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+ return true;
+ } else {
+ return hasNullableChild();
+ }
+ }
+ if ((fn.functionName().equalsIgnoreCase(Operator.ADD.getName())
+ || fn.functionName().equalsIgnoreCase(Operator.SUBTRACT.getName()))
+ && fn.getReturnType().isDecimalV3()) {
+ if (ConnectContext.get() != null
+ && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+ return true;
+ } else {
+ return hasNullableChild();
+ }
+ }
return true;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 7faa492295..dcea629834 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -189,6 +189,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_PROJECTION = "enable_projection";
+ public static final String CHECK_OVERFLOW_FOR_DECIMAL = "check_overflow_for_decimal";
+
public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY
= "trim_tailing_spaces_for_external_table_query";
@@ -542,6 +544,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_PROJECTION)
private boolean enableProjection = true;
+ @VariableMgr.VarAttr(name = CHECK_OVERFLOW_FOR_DECIMAL)
+ private boolean checkOverflowForDecimal = false;
+
/**
* as the new optimizer is not mature yet, use this var
* to control whether to use new optimizer, remove it when
@@ -1235,6 +1240,10 @@ public class SessionVariable implements Serializable, Writable {
return enableProjection;
}
+ public boolean checkOverflowForDecimal() {
+ return checkOverflowForDecimal;
+ }
+
public boolean isTrimTailingSpacesForExternalTableQuery() {
return trimTailingSpacesForExternalTableQuery;
}
@@ -1368,6 +1377,7 @@ public class SessionVariable implements Serializable, Writable {
}
tResult.setEnableFunctionPushdown(enableFunctionPushdown);
+ tResult.setCheckOverflowForDecimal(checkOverflowForDecimal);
tResult.setFragmentTransmissionCompressionCodec(fragmentTransmissionCompressionCodec);
tResult.setEnableLocalExchange(enableLocalExchange);
tResult.setEnableNewShuffleHashMethod(enableNewShuffleHashMethod);
diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift
index 32d6721bb3..bb241c4aa0 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -187,6 +187,7 @@ struct TQueryOptions {
55: optional bool enable_pipeline_engine = false
56: optional i32 repeat_max_num = 0
+ 57: optional bool check_overflow_for_decimal = false
}
diff --git a/regression-test/data/datatype_p0/decimalv3/test_overflow.out b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
new file mode 100644
index 0000000000..c9b9873cd7
--- /dev/null
+++ b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
@@ -0,0 +1,19 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !select_all --
+11111111111111111111.100000000000000000 11111111111111111111.200000000000000000 11111111111111111111.300000000000000000 1.1000000000000000000000000000000000000 1.2000000000000000000000000000000000000 1.3000000000000000000000000000000000000 9
+
+-- !select_check_overflow1 --
+\N \N \N 99999999999999999999.900000000000000000 \N
+
+-- !select_check_overflow2 --
+1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 \N
+
+-- !select_check_overflow3 --
+11111111111111111111.100000000000000000 \N
+
+-- !select_not_check_overflow1 --
+99.999999999999999999999999999999999999 99.999999999999999999999999999999999999 1.1111111111111111E21 99999999999999999999.900000000000000000 99999999999999999999.999999999999999999
+
+-- !select_not_check_overflow2 --
+1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 -15.9141183460469231731687303715884105728
+
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
new file mode 100644
index 0000000000..01de2ea498
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_overflow") {
+
+ def table1 = "test_overflow"
+
+ sql "drop table if exists ${table1}"
+
+ sql """
+ CREATE TABLE IF NOT EXISTS test_overflow (
+ `k1` decimalv3(38, 18) NULL COMMENT "",
+ `k2` decimalv3(38, 18) NULL COMMENT "",
+ `k3` decimalv3(38, 18) NULL COMMENT "",
+ `v1` decimalv3(38, 37) NULL COMMENT "",
+ `v2` decimalv3(38, 37) NULL COMMENT "",
+ `v3` decimalv3(38, 37) NULL COMMENT "",
+ `v4` INT NULL COMMENT ""
+ ) ENGINE=OLAP
+ COMMENT "OLAP"
+ DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "in_memory" = "false",
+ "storage_format" = "V2"
+ )
+ """
+
+ sql """insert into test_overflow values(11111111111111111111.1,11111111111111111111.2,11111111111111111111.3, 1.1,1.2,1.3,9)
+ """
+ qt_select_all "select * from test_overflow order by k1"
+
+ sql " SET check_overflow_for_decimal = true; "
+ qt_select_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;"
+ qt_select_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow"
+ qt_select_check_overflow3 "select `k1`, cast (`k1` as DECIMALV3(38, 36)) from test_overflow;"
+
+ sql " SET check_overflow_for_decimal = false; "
+ qt_select_not_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from test_overflow;"
+ qt_select_not_check_overflow2 "select v1, k1*10, v1 +k1*10 from test_overflow"
+ sql "drop table if exists ${table1}"
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org