You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2023/01/28 06:06:30 UTC
[doris] branch branch-1.2-lts updated: [Decimalv3/DateV2](cherrypick) pick related commits from master to 1.2 (#16059)
This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
new d2a619706b [Decimalv3/DateV2](cherrypick) pick related commits from master to 1.2 (#16059)
d2a619706b is described below
commit d2a619706bc4435f2e2c00877cf874d10731ab51
Author: Gabriel <ga...@gmail.com>
AuthorDate: Sat Jan 28 14:06:24 2023 +0800
[Decimalv3/DateV2](cherrypick) pick related commits from master to 1.2 (#16059)
#15463
#15505
#15574
#15644
#15696
#15674
#15930
#16050
#15312
#14814
#15860
#15885
#15915
#15985
#16053
---
be/src/runtime/collection_value.cpp | 50 ++
be/src/runtime/runtime_state.h | 5 +
be/src/udf/udf_internal.h | 8 +
be/src/util/date_func.cpp | 21 +
be/src/util/date_func.h | 2 +
be/src/util/string_parser.hpp | 2 +-
be/src/vec/CMakeLists.txt | 1 +
be/src/vec/core/block.h | 4 +-
be/src/vec/data_types/data_type_decimal.h | 30 +-
be/src/vec/data_types/data_type_factory.cpp | 4 +
be/src/vec/data_types/data_type_time.cpp | 52 +++
be/src/vec/data_types/data_type_time.h | 53 +++
be/src/vec/exec/scan/vscan_node.cpp | 8 +-
be/src/vec/exprs/vexpr_context.cpp | 2 +
be/src/vec/functions/function_binary_arithmetic.h | 142 ++++--
be/src/vec/functions/function_cast.h | 124 +++--
.../function_date_or_datetime_computation.cpp | 99 ++--
.../function_date_or_datetime_computation.h | 227 ++++-----
.../function_date_or_datetime_computation_v2.cpp | 174 +++----
be/src/vec/functions/function_running_difference.h | 3 +-
be/src/vec/functions/function_string.h | 97 +++-
be/src/vec/runtime/vdatetime_value.cpp | 1 +
be/src/vec/sink/vmysql_result_writer.cpp | 15 +-
.../function/function_running_difference_test.cpp | 2 +-
be/test/vec/function/function_test_util.h | 7 +-
be/test/vec/function/function_time_test.cpp | 22 +-
.../java/org/apache/doris/catalog/ScalarType.java | 8 +-
.../main/java/org/apache/doris/catalog/Type.java | 23 +-
.../org/apache/doris/analysis/ArithmeticExpr.java | 25 +-
.../org/apache/doris/analysis/BinaryPredicate.java | 8 +-
.../org/apache/doris/analysis/DateLiteral.java | 11 +
.../main/java/org/apache/doris/analysis/Expr.java | 21 +
.../org/apache/doris/analysis/FloatLiteral.java | 4 +-
.../apache/doris/analysis/FunctionCallExpr.java | 5 +-
.../apache/doris/analysis/SetOperationStmt.java | 6 +
.../org/apache/doris/analysis/StringLiteral.java | 7 +-
.../doris/analysis/TimestampArithmeticExpr.java | 7 +-
.../java/org/apache/doris/planner/ScanNode.java | 5 +-
.../java/org/apache/doris/qe/SessionVariable.java | 10 +
.../RoundLiteralInBinaryPredicatesRule.java | 5 +-
.../analysis/CreateTableAsSelectStmtTest.java | 20 +-
.../org/apache/doris/analysis/QueryStmtTest.java | 6 +-
.../org/apache/doris/analysis/SelectStmtTest.java | 4 +-
.../org/apache/doris/planner/QueryPlanTest.java | 40 +-
.../doris/rewrite/RewriteDateLiteralRuleTest.java | 37 +-
.../udf/{UdfExecutor.java => BaseExecutor.java} | 513 ++++++++-------------
.../java/org/apache/doris/udf/UdafExecutor.java | 344 ++------------
.../java/org/apache/doris/udf/UdfExecutor.java | 349 ++------------
.../main/java/org/apache/doris/udf/UdfUtils.java | 43 +-
gensrc/script/doris_builtins_functions.py | 177 ++++---
gensrc/thrift/PaloInternalService.thrift | 2 +
.../data/correctness_p0/test_pushdown_constant.out | 6 +
.../storage/test_dup_tab_datetime_nullable.out | 1 -
.../decimalv3/test_load.out} | 6 +-
.../data/datatype_p0/decimalv3/test_overflow.out | 19 +
.../data/datatype_p0/decimalv3/test_predicate.out | 13 +
.../array_functions/test_array_functions.out | 12 +
.../math_functions/test_running_difference.out | 12 +-
.../correctness_p0/test_pushdown_constant.groovy | 31 +-
.../datatype_p0/decimalv3/test_data/test.csv | 3 +
.../suites/datatype_p0/decimalv3/test_load.groovy | 58 +++
.../datatype_p0/decimalv3/test_overflow.groovy | 56 +++
.../datatype_p0/decimalv3/test_predicate.groovy | 47 ++
.../array_functions/test_array_functions.groovy | 27 ++
.../datetime_functions/test_date_function.groovy | 2 +-
65 files changed, 1555 insertions(+), 1573 deletions(-)
diff --git a/be/src/runtime/collection_value.cpp b/be/src/runtime/collection_value.cpp
index 13185e1cba..8192e35dc7 100644
--- a/be/src/runtime/collection_value.cpp
+++ b/be/src/runtime/collection_value.cpp
@@ -115,12 +115,42 @@ struct CollectionValueSubTypeTrait<TYPE_DATETIME> {
using AnyValType = DateTimeVal;
};
+template <>
+struct CollectionValueSubTypeTrait<TYPE_DATEV2> {
+ using CppType = uint32_t;
+ using AnyValType = IntVal;
+};
+
+template <>
+struct CollectionValueSubTypeTrait<TYPE_DATETIMEV2> {
+ using CppType = uint64_t;
+ using AnyValType = BigIntVal;
+};
+
template <>
struct CollectionValueSubTypeTrait<TYPE_DECIMALV2> {
using CppType = decimal12_t;
using AnyValType = DecimalV2Val;
};
+template <>
+struct CollectionValueSubTypeTrait<TYPE_DECIMAL32> {
+ using CppType = int32_t;
+ using AnyValType = IntVal;
+};
+
+template <>
+struct CollectionValueSubTypeTrait<TYPE_DECIMAL64> {
+ using CppType = int64_t;
+ using AnyValType = BigIntVal;
+};
+
+template <>
+struct CollectionValueSubTypeTrait<TYPE_DECIMAL128I> {
+ using CppType = int128_t;
+ using AnyValType = LargeIntVal;
+};
+
template <>
struct CollectionValueSubTypeTrait<TYPE_ARRAY> {
using CppType = CollectionValue;
@@ -352,12 +382,27 @@ ArrayIterator CollectionValue::internal_iterator(PrimitiveType child_type) const
case TYPE_DATETIME:
return ArrayIterator(const_cast<CollectionValue*>(this),
static_cast<ArrayIteratorFunctions<TYPE_DATETIME>*>(nullptr));
+ case TYPE_DATEV2:
+ return ArrayIterator(const_cast<CollectionValue*>(this),
+ static_cast<ArrayIteratorFunctions<TYPE_DATEV2>*>(nullptr));
+ case TYPE_DATETIMEV2:
+ return ArrayIterator(const_cast<CollectionValue*>(this),
+ static_cast<ArrayIteratorFunctions<TYPE_DATETIMEV2>*>(nullptr));
case TYPE_ARRAY:
return ArrayIterator(const_cast<CollectionValue*>(this),
static_cast<ArrayIteratorFunctions<TYPE_ARRAY>*>(nullptr));
case TYPE_DECIMALV2:
return ArrayIterator(const_cast<CollectionValue*>(this),
static_cast<ArrayIteratorFunctions<TYPE_DECIMALV2>*>(nullptr));
+ case TYPE_DECIMAL32:
+ return ArrayIterator(const_cast<CollectionValue*>(this),
+ static_cast<ArrayIteratorFunctions<TYPE_DECIMAL32>*>(nullptr));
+ case TYPE_DECIMAL64:
+ return ArrayIterator(const_cast<CollectionValue*>(this),
+ static_cast<ArrayIteratorFunctions<TYPE_DECIMAL64>*>(nullptr));
+ case TYPE_DECIMAL128I:
+ return ArrayIterator(const_cast<CollectionValue*>(this),
+ static_cast<ArrayIteratorFunctions<TYPE_DECIMAL128I>*>(nullptr));
default:
DCHECK(false) << "Invalid child type: " << child_type;
__builtin_unreachable();
@@ -389,8 +434,13 @@ Status type_check(PrimitiveType type) {
case TYPE_DATE:
case TYPE_DATETIME:
+ case TYPE_DATEV2:
+ case TYPE_DATETIMEV2:
case TYPE_DECIMALV2:
+ case TYPE_DECIMAL32:
+ case TYPE_DECIMAL64:
+ case TYPE_DECIMAL128I:
case TYPE_ARRAY:
break;
diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h
index 7b8ae4d89d..f68de8d714 100644
--- a/be/src/runtime/runtime_state.h
+++ b/be/src/runtime/runtime_state.h
@@ -138,6 +138,11 @@ public:
_query_options.enable_function_pushdown;
}
+ bool check_overflow_for_decimal() const {
+ return _query_options.__isset.check_overflow_for_decimal &&
+ _query_options.check_overflow_for_decimal;
+ }
+
// Create a codegen object in _codegen. No-op if it has already been called.
// If codegen is enabled for the query, this is created when the runtime
// state is created. If codegen is disabled for the query, this is created
diff --git a/be/src/udf/udf_internal.h b/be/src/udf/udf_internal.h
index 1bc4fefd0b..67a8ec60e7 100644
--- a/be/src/udf/udf_internal.h
+++ b/be/src/udf/udf_internal.h
@@ -109,6 +109,12 @@ public:
const doris_udf::FunctionContext::TypeDesc& get_return_type() const { return _return_type; }
+ const bool check_overflow_for_decimal() const { return _check_overflow_for_decimal; }
+
+ bool set_check_overflow_for_decimal(bool check_overflow_for_decimal) {
+ return _check_overflow_for_decimal = check_overflow_for_decimal;
+ }
+
private:
friend class doris_udf::FunctionContext;
friend class ExprContext;
@@ -181,6 +187,8 @@ private:
// call that passes the correct AnyVal subclass pointer type.
std::vector<doris_udf::AnyVal*> _staging_input_vals;
+ bool _check_overflow_for_decimal = false;
+
// Indicates whether this context has been closed. Used for verification/debugging.
bool _closed;
diff --git a/be/src/util/date_func.cpp b/be/src/util/date_func.cpp
index 2c0ebdd811..e324d68106 100644
--- a/be/src/util/date_func.cpp
+++ b/be/src/util/date_func.cpp
@@ -109,4 +109,25 @@ int32_t time_to_buffer_from_double(double time, char* buffer) {
return buffer - begin;
}
+std::string time_to_buffer_from_double(double time) {
+ fmt::memory_buffer buffer;
+ if (time < 0) {
+ time = -time;
+ fmt::format_to(buffer, "-");
+ }
+ if (time > 3020399) {
+ time = 3020399;
+ }
+ int64_t hour = (int64_t)(time / 3600);
+ int32_t minute = ((int32_t)(time / 60)) % 60;
+ int32_t second = ((int32_t)time) % 60;
+ if (hour >= 100) {
+ fmt::format_to(buffer, fmt::format("{}", hour));
+ } else {
+ fmt::format_to(buffer, fmt::format("{:02d}", hour));
+ }
+ fmt::format_to(buffer, fmt::format(":{:02d}:{:02d}", minute, second));
+ return fmt::to_string(buffer);
+}
+
} // namespace doris
diff --git a/be/src/util/date_func.h b/be/src/util/date_func.h
index e5843b64af..4378fd32e1 100644
--- a/be/src/util/date_func.h
+++ b/be/src/util/date_func.h
@@ -32,4 +32,6 @@ int32_t time_to_buffer_from_double(double time, char* buffer);
uint32_t timestamp_from_date_v2(const std::string& date_str);
uint64_t timestamp_from_datetime_v2(const std::string& date_str);
+std::string time_to_buffer_from_double(double time);
+
} // namespace doris
diff --git a/be/src/util/string_parser.hpp b/be/src/util/string_parser.hpp
index 7562c22d06..653f0dac14 100644
--- a/be/src/util/string_parser.hpp
+++ b/be/src/util/string_parser.hpp
@@ -757,8 +757,8 @@ inline T StringParser::string_to_decimal(const char* s, int len, int type_precis
divisor = get_scale_multiplier<T>(shift);
}
if (LIKELY(divisor >= 0)) {
- value /= divisor;
T remainder = value % divisor;
+ value /= divisor;
if ((remainder > 0 ? T(remainder) : T(-remainder)) >= (divisor >> 1)) {
value += 1;
}
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 4ecb93ce0e..723f83758f 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -88,6 +88,7 @@ set(VEC_FILES
data_types/data_type_date_time.cpp
data_types/data_type_time_v2.cpp
data_types/data_type_jsonb.cpp
+ data_types/data_type_time.cpp
exec/vaggregation_node.cpp
exec/varrow_scanner.cpp
exec/vsort_node.cpp
diff --git a/be/src/vec/core/block.h b/be/src/vec/core/block.h
index db508f2b2f..af6f51c67c 100644
--- a/be/src/vec/core/block.h
+++ b/be/src/vec/core/block.h
@@ -467,7 +467,9 @@ public:
DCHECK_EQ(_columns.size(), block.columns());
for (int i = 0; i < _columns.size(); ++i) {
if (!_data_types[i]->equals(*block.get_by_position(i).type)) {
- DCHECK(_data_types[i]->is_nullable());
+ DCHECK(_data_types[i]->is_nullable())
+ << " target type: " << _data_types[i]->get_name()
+ << " src type: " << block.get_by_position(i).type->get_name();
DCHECK(((DataTypeNullable*)_data_types[i].get())
->get_nested_type()
->equals(*block.get_by_position(i).type));
diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h
index ffc0564687..358fe79438 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -71,6 +71,11 @@ constexpr Int128 max_decimal_value<Decimal128>() {
return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
}
+template <>
+constexpr Int128 max_decimal_value<Decimal128I>() {
+ return static_cast<int128_t>(999999999999999999ll) * 100000000000000000ll * 1000ll +
+ static_cast<int128_t>(99999999999999999ll) * 1000ll + 999ll;
+}
DataTypePtr create_decimal(UInt64 precision, UInt64 scale, bool use_v2);
@@ -290,8 +295,8 @@ constexpr bool IsDataTypeDecimalOrNumber =
template <typename FromDataType, typename ToDataType>
inline std::enable_if_t<IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>,
typename ToDataType::FieldType>
-convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from,
- UInt32 scale_to) {
+convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_from, UInt32 scale_to,
+ UInt8* overflow_flag = nullptr) {
using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
using MaxFieldType =
@@ -309,6 +314,9 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_to - scale_from);
if (common::mul_overflow(static_cast<MaxNativeType>(value), converted_value,
converted_value)) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return converted_value < 0
? std::numeric_limits<typename ToFieldType::NativeType>::min()
@@ -321,10 +329,16 @@ convert_decimals(const typename FromDataType::FieldType& value, UInt32 scale_fro
if constexpr (sizeof(FromFieldType) > sizeof(ToFieldType)) {
if (converted_value < std::numeric_limits<typename ToFieldType::NativeType>::min()) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return std::numeric_limits<typename ToFieldType::NativeType>::min();
}
if (converted_value > std::numeric_limits<typename ToFieldType::NativeType>::max()) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow";
return std::numeric_limits<typename ToFieldType::NativeType>::max();
}
@@ -380,12 +394,16 @@ convert_from_decimal(const typename FromDataType::FieldType& value, UInt32 scale
template <typename FromDataType, typename ToDataType>
inline std::enable_if_t<IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>,
typename ToDataType::FieldType>
-convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale) {
+convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale,
+ UInt8* overflow_flag) {
using FromFieldType = typename FromDataType::FieldType;
using ToNativeType = typename ToDataType::FieldType::NativeType;
if constexpr (std::is_floating_point_v<FromFieldType>) {
if (!std::isfinite(value)) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Cannot convert infinity or NaN to decimal";
return value < 0 ? std::numeric_limits<ToNativeType>::min()
: std::numeric_limits<ToNativeType>::max();
@@ -394,10 +412,16 @@ convert_to_decimal(const typename FromDataType::FieldType& value, UInt32 scale)
FromFieldType out;
out = value * ToDataType::get_scale_multiplier(scale);
if (out <= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::min())) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
return std::numeric_limits<ToNativeType>::min();
}
if (out >= static_cast<FromFieldType>(std::numeric_limits<ToNativeType>::max())) {
+ if (overflow_flag) {
+ *overflow_flag = 1;
+ }
VLOG_DEBUG << "Decimal convert overflow. Float is out of Decimal range";
return std::numeric_limits<ToNativeType>::max();
}
diff --git a/be/src/vec/data_types/data_type_factory.cpp b/be/src/vec/data_types/data_type_factory.cpp
index 85dfc445ba..e622d979d5 100644
--- a/be/src/vec/data_types/data_type_factory.cpp
+++ b/be/src/vec/data_types/data_type_factory.cpp
@@ -20,6 +20,8 @@
#include "vec/data_types/data_type_factory.hpp"
+#include "data_type_time.h"
+
namespace doris::vectorized {
DataTypePtr DataTypeFactory::create_data_type(const doris::Field& col_desc) {
@@ -92,6 +94,8 @@ DataTypePtr DataTypeFactory::create_data_type(const TypeDescriptor& col_desc, bo
break;
case TYPE_TIME:
case TYPE_TIMEV2:
+ nested = std::make_shared<vectorized::DataTypeTime>();
+ break;
case TYPE_DOUBLE:
nested = std::make_shared<vectorized::DataTypeFloat64>();
break;
diff --git a/be/src/vec/data_types/data_type_time.cpp b/be/src/vec/data_types/data_type_time.cpp
new file mode 100644
index 0000000000..caa4e0530c
--- /dev/null
+++ b/be/src/vec/data_types/data_type_time.cpp
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/DataTypes/DataTypeDateTime.cpp
+// and modified by Doris
+
+#include "vec/data_types/data_type_time.h"
+
+#include "util/date_func.h"
+#include "vec/columns/columns_number.h"
+
+namespace doris::vectorized {
+
+bool DataTypeTime::equals(const IDataType& rhs) const {
+ return typeid(rhs) == typeid(*this);
+}
+
+std::string DataTypeTime::to_string(const IColumn& column, size_t row_num) const {
+ Float64 float_val =
+ assert_cast<const ColumnFloat64&>(*column.convert_to_full_column_if_const().get())
+ .get_data()[row_num];
+ return time_to_buffer_from_double(float_val);
+}
+
+void DataTypeTime::to_string(const IColumn& column, size_t row_num, BufferWritable& ostr) const {
+ Float64 float_val =
+ assert_cast<const ColumnFloat64&>(*column.convert_to_full_column_if_const().get())
+ .get_data()[row_num];
+ std::string time_val = time_to_buffer_from_double(float_val);
+ // time_val is a std::string with no trailing '\0', so write exactly size() bytes.
+ ostr.write(time_val.data(), time_val.size());
+}
+
+MutableColumnPtr DataTypeTime::create_column() const {
+ return DataTypeNumberBase<Float64>::create_column();
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/data_types/data_type_time.h b/be/src/vec/data_types/data_type_time.h
new file mode 100644
index 0000000000..b10c9d88d5
--- /dev/null
+++ b/be/src/vec/data_types/data_type_time.h
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/DataTypes/DataTypeDateTime.h
+// and modified by Doris
+
+#pragma once
+
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_number_base.h"
+
+namespace doris::vectorized {
+
+class DataTypeTime final : public DataTypeNumberBase<Float64> {
+public:
+ DataTypeTime() = default;
+
+ bool equals(const IDataType& rhs) const override;
+
+ std::string to_string(const IColumn& column, size_t row_num) const override;
+
+ void to_string(const IColumn& column, size_t row_num, BufferWritable& ostr) const override;
+
+ MutableColumnPtr create_column() const override;
+
+ bool can_be_used_as_version() const override { return true; }
+ bool is_summable() const override { return true; }
+ bool can_be_used_in_bit_operations() const override { return true; }
+ bool can_be_used_in_boolean_context() const override { return true; }
+ bool can_be_inside_nullable() const override { return true; }
+
+ bool can_be_promoted() const override { return true; }
+ DataTypePtr promote_numeric_type() const override {
+ using PromotedType = DataTypeNumber<NearestFieldType<Float64>>;
+ return std::make_shared<PromotedType>();
+ }
+};
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/exec/scan/vscan_node.cpp b/be/src/vec/exec/scan/vscan_node.cpp
index bbe0a75523..e0ba630f67 100644
--- a/be/src/vec/exec/scan/vscan_node.cpp
+++ b/be/src/vec/exec/scan/vscan_node.cpp
@@ -565,6 +565,11 @@ bool VScanNode::_is_predicate_acting_on_slot(
// the type of predicate not match the slot's type
return false;
}
+ } else if (child_contains_slot->type().is_datetime_type() &&
+ child_contains_slot->node_type() == doris::TExprNodeType::CAST_EXPR) {
+ // Expr `CAST(CAST(datetime_col AS DATE) AS DATETIME) = datetime_literal` should not be
+ // pushed down.
+ return false;
}
*range = &(entry->second.second);
return true;
@@ -705,7 +710,8 @@ Status VScanNode::_normalize_not_in_and_not_eq_predicate(VExpr* expr, VExprConte
ColumnValueRange<T>& range,
PushDownType* pdt) {
bool is_fixed_range = range.is_fixed_value_range();
- auto not_in_range = ColumnValueRange<T>::create_empty_column_value_range(range.column_name());
+ auto not_in_range = ColumnValueRange<T>::create_empty_column_value_range(
+ range.column_name(), slot->type().precision, slot->type().scale);
PushDownType temp_pdt = PushDownType::UNACCEPTABLE;
// 1. Normalize in conjuncts like 'where col in (v1, v2, v3)'
if (TExprNodeType::IN_PRED == expr->node_type()) {
diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp
index ccb1045cb1..9033245202 100644
--- a/be/src/vec/exprs/vexpr_context.cpp
+++ b/be/src/vec/exprs/vexpr_context.cpp
@@ -110,6 +110,8 @@ int VExprContext::register_func(RuntimeState* state, const FunctionContext::Type
int varargs_buffer_size) {
_fn_contexts.push_back(FunctionContextImpl::create_context(
state, _pool.get(), return_type, arg_types, varargs_buffer_size, false));
+ _fn_contexts.back()->impl()->set_check_overflow_for_decimal(
+ state->check_overflow_for_decimal());
return _fn_contexts.size() - 1;
}
diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h
index 2a8da748e3..d1c0375a55 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -23,6 +23,7 @@
#include <type_traits>
#include "runtime/decimalv2_value.h"
+#include "udf/udf_internal.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_nullable.h"
@@ -216,7 +217,8 @@ struct BinaryOperationImpl {
/// * no args scale. ScaleR = Scale1 + Scale2;
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::get_scale()).
template <typename A, typename B, template <typename, typename> typename Operation,
- typename ResultType, bool is_to_null_type, bool check_overflow = true>
+ typename ResultType, bool is_to_null_type, bool return_nullable_type,
+ bool check_overflow = true>
struct DecimalBinaryOperation {
using OpTraits = OperationTraits<Operation>;
@@ -249,12 +251,14 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply(a[i], b[i], null_map[i]);
}
- } else {
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
- }
- return;
+ } else if constexpr (OpTraits::is_division && (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply_scaled_div(a[i], b[i], null_map[i]);
+ }
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a[i], b[i], c[i].value);
}
}
}
@@ -281,21 +285,21 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a[i], b, null_map[i]);
}
- return;
- }
-
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply(a[i], b, null_map[i]);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a[i], b, c[i].value);
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply(a[i], b, null_map[i]);
+ }
}
}
static void constant_vector(A a, const typename Traits::ArrayB& b, ArrayC& c) {
size_t size = b.size();
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply_scaled_div(a, b[i]);
- }
- } else if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
+ if constexpr (IsDecimalV2<A> || IsDecimalV2<B>) {
DecimalV2Value da(a);
for (size_t i = 0; i < size; ++i) {
c[i] = Op::template apply(da, DecimalV2Value(b[i])).value();
@@ -314,33 +318,43 @@ struct DecimalBinaryOperation {
for (size_t i = 0; i < size; ++i) {
c[i] = apply_scaled_div(a, b[i], null_map[i]);
}
- return;
- }
-
- for (size_t i = 0; i < size; ++i) {
- c[i] = apply(a, b[i], null_map[i]);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ for (size_t i = 0; i < size; ++i) {
+ null_map[i] = apply_op_safely(a, b[i], c[i].value);
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
+ c[i] = apply(a, b[i], null_map[i]);
+ }
}
}
- static ResultType constant_constant(A a, B b) {
- if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
- return apply_scaled_div(a, b);
- }
- return apply(a, b);
- }
+ static ResultType constant_constant(A a, B b) { return apply(a, b); }
static ResultType constant_constant(A a, B b, UInt8& is_null) {
if constexpr (OpTraits::is_division && IsDecimalNumber<B>) {
return apply_scaled_div(a, b, is_null);
+ } else if constexpr ((OpTraits::is_multiply || OpTraits::is_plus_minus) &&
+ (IsDecimalNumber<B> || IsDecimalNumber<A>)) {
+ NativeResultType res;
+ is_null = apply_op_safely(a, b, res);
+ return res;
+ } else {
+ return apply(a, b, is_null);
}
- return apply(a, b, is_null);
}
static ColumnPtr adapt_decimal_constant_constant(A a, B b, DataTypePtr res_data_type) {
auto column_result = ColumnDecimal<ResultType>::create(
1, assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(1, 0);
column_result->get_element(0) = constant_constant(a, b, null_map->get_element(0));
return ColumnNullable::create(std::move(column_result), std::move(null_map));
@@ -358,7 +372,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_left_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_left->size(), 0);
vector_constant(column_left_ptr->get_data(), b, column_result->get_data(),
null_map->get_data());
@@ -377,7 +396,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_right_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_right->size(), 0);
constant_vector(a, column_right_ptr->get_data(), column_result->get_data(),
null_map->get_data());
@@ -398,7 +422,12 @@ struct DecimalBinaryOperation {
assert_cast<const DataTypeDecimal<ResultType>&>(*res_data_type).get_scale());
DCHECK(column_left_ptr != nullptr && column_right_ptr != nullptr);
- if constexpr (is_to_null_type) {
+ if constexpr (return_nullable_type && !is_to_null_type &&
+ ((!OpTraits::is_multiply && !OpTraits::is_plus_minus) || IsDecimalV2<A> ||
+ IsDecimalV2<B>)) {
+ LOG(FATAL) << "Invalid function type!";
+ return column_result;
+ } else if constexpr (return_nullable_type || is_to_null_type) {
auto null_map = ColumnUInt8::create(column_result->size(), 0);
vector_vector(column_left_ptr->get_data(), column_right_ptr->get_data(),
column_result->get_data(), null_map->get_data());
@@ -483,6 +512,12 @@ private:
UInt8& is_null) {
return apply(a, b, is_null);
}
+
+ static UInt8 apply_op_safely(NativeResultType a, NativeResultType b, NativeResultType& c) {
+ if constexpr (OpTraits::is_multiply || OpTraits::is_plus_minus) {
+ return Op::template apply(a, b, c);
+ }
+ }
};
/// Used to indicate undefined operation
@@ -568,7 +603,8 @@ struct BinaryOperationTraits {
};
template <typename LeftDataType, typename RightDataType, typename ExpectedResultDataType,
- template <typename, typename> class Operation, bool is_to_null_type>
+ template <typename, typename> class Operation, bool is_to_null_type,
+ bool return_nullable_type>
struct ConstOrVectorAdapter {
static constexpr bool result_is_decimal =
IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
@@ -580,7 +616,8 @@ struct ConstOrVectorAdapter {
using OperationImpl = std::conditional_t<
IsDataTypeDecimal<ResultDataType>,
- DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type>,
+ DecimalBinaryOperation<A, B, Operation, ResultType, is_to_null_type,
+ return_nullable_type>,
BinaryOperationImpl<A, B, Operation<A, B>, is_to_null_type, ResultType>>;
static ColumnPtr execute(ColumnPtr column_left, ColumnPtr column_right,
@@ -774,6 +811,7 @@ public:
right_generic =
static_cast<const DataTypeNullable*>(right_generic)->get_nested_type().get();
}
+ bool result_is_nullable = context->impl()->check_overflow_for_decimal();
if (result_generic->is_nullable()) {
result_generic =
static_cast<const DataTypeNullable*>(result_generic)->get_nested_type().get();
@@ -795,15 +833,31 @@ public:
ResultDataType>)&&(IsDataTypeDecimal<ExpectedResultDataType> ==
(IsDataTypeDecimal<LeftDataType> ||
IsDataTypeDecimal<RightDataType>))) {
- auto column_result = ConstOrVectorAdapter<
- LeftDataType, RightDataType,
- std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
- ExpectedResultDataType, ResultDataType>,
- Operation, is_to_null_type>::
- execute(block.get_by_position(arguments[0]).column,
- block.get_by_position(arguments[1]).column, left, right,
- remove_nullable(block.get_by_position(result).type));
- block.replace_by_position(result, std::move(column_result));
+ if (result_is_nullable) {
+ auto column_result = ConstOrVectorAdapter<
+ LeftDataType, RightDataType,
+ std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+ ExpectedResultDataType, ResultDataType>,
+ Operation, is_to_null_type,
+ true>::execute(block.get_by_position(arguments[0]).column,
+ block.get_by_position(arguments[1]).column, left,
+ right,
+ remove_nullable(
+ block.get_by_position(result).type));
+ block.replace_by_position(result, std::move(column_result));
+ } else {
+ auto column_result = ConstOrVectorAdapter<
+ LeftDataType, RightDataType,
+ std::conditional_t<IsDataTypeDecimal<ExpectedResultDataType>,
+ ExpectedResultDataType, ResultDataType>,
+ Operation, is_to_null_type,
+ false>::execute(block.get_by_position(arguments[0]).column,
+ block.get_by_position(arguments[1]).column,
+ left, right,
+ remove_nullable(
+ block.get_by_position(result).type));
+ block.replace_by_position(result, std::move(column_result));
+ }
return true;
}
return false;
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index 90edd3906b..e3baaecdd2 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -22,6 +22,7 @@
#include <fmt/format.h>
+#include "udf/udf_internal.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_const.h"
#include "vec/columns/column_nullable.h"
@@ -72,7 +73,7 @@ struct ConvertImpl {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t /*input_rows_count*/,
+ size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
const ColumnWithTypeAndName& named_from = block.get_by_position(arguments[0]);
@@ -96,37 +97,50 @@ struct ConvertImpl {
if constexpr (IsDataTypeDecimal<ToDataType>) {
UInt32 scale = additions;
col_to = ColVecTo::create(0, scale);
- } else
+ } else {
col_to = ColVecTo::create();
+ }
const auto& vec_from = col_from->get_data();
auto& vec_to = col_to->get_data();
size_t size = vec_from.size();
vec_to.resize(size);
- for (size_t i = 0; i < size; ++i) {
- if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
- if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
+ if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>) {
+ ColumnUInt8::MutablePtr col_null_map_to = nullptr;
+ UInt8* vec_null_map_to = nullptr;
+ if (check_overflow) {
+ col_null_map_to = ColumnUInt8::create(size, 0);
+ vec_null_map_to = col_null_map_to->get_data().data();
+ }
+ for (size_t i = 0; i < size; ++i) {
+ if constexpr (IsDataTypeDecimal<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_decimals<FromDataType, ToDataType>(
- vec_from[i], vec_from.get_scale(), vec_to.get_scale());
- else if constexpr (IsDataTypeDecimal<FromDataType> &&
- IsDataTypeNumber<ToDataType>)
+ vec_from[i], vec_from.get_scale(), vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+ } else if constexpr (IsDataTypeDecimal<FromDataType> &&
+ IsDataTypeNumber<ToDataType>) {
vec_to[i] = convert_from_decimal<FromDataType, ToDataType>(
vec_from[i], vec_from.get_scale());
- else if constexpr (IsDataTypeNumber<FromDataType> &&
- IsDataTypeDecimal<ToDataType>)
+ } else if constexpr (IsDataTypeNumber<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<FromDataType, ToDataType>(
- vec_from[i], vec_to.get_scale());
- else if constexpr (IsTimeType<FromDataType> && IsDataTypeDecimal<ToDataType>) {
+ vec_from[i], vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
+ } else if constexpr (IsTimeType<FromDataType> &&
+ IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<DataTypeInt64, ToDataType>(
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
} else if constexpr (IsDateV2Type<FromDataType> &&
IsDataTypeDecimal<ToDataType>) {
vec_to[i] = convert_to_decimal<DataTypeUInt32, ToDataType>(
reinterpret_cast<const DateV2Value<DateV2ValueType>&>(vec_from[i])
.to_date_int_val(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
} else if constexpr (IsDateTimeV2Type<FromDataType> &&
IsDataTypeDecimal<ToDataType>) {
// TODO: should we consider the scale of datetimev2?
@@ -134,9 +148,21 @@ struct ConvertImpl {
reinterpret_cast<const DateV2Value<DateTimeV2ValueType>&>(
vec_from[i])
.to_date_int_val(),
- vec_to.get_scale());
+ vec_to.get_scale(),
+ vec_null_map_to ? &vec_null_map_to[i] : vec_null_map_to);
}
- } else if constexpr (IsTimeType<FromDataType>) {
+ }
+ if (check_overflow) {
+ block.replace_by_position(
+ result,
+ ColumnNullable::create(std::move(col_to), std::move(col_null_map_to)));
+ } else {
+ block.replace_by_position(result, std::move(col_to));
+ }
+
+ return Status::OK();
+ } else if constexpr (IsTimeType<FromDataType>) {
+ for (size_t i = 0; i < size; ++i) {
if constexpr (IsTimeType<ToDataType>) {
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
if constexpr (IsDateTimeType<ToDataType>) {
@@ -152,7 +178,9 @@ struct ConvertImpl {
vec_to[i] =
reinterpret_cast<const VecDateTimeValue&>(vec_from[i]).to_int64();
}
- } else if constexpr (IsTimeV2Type<FromDataType>) {
+ }
+ } else if constexpr (IsTimeV2Type<FromDataType>) {
+ for (size_t i = 0; i < size; ++i) {
if constexpr (IsTimeV2Type<ToDataType>) {
if constexpr (IsDateTimeV2Type<ToDataType> && IsDateV2Type<FromDataType>) {
DataTypeDateV2::cast_to_date_time_v2(vec_from[i], vec_to[i]);
@@ -189,7 +217,9 @@ struct ConvertImpl {
.to_int64();
}
}
- } else {
+ }
+ } else {
+ for (size_t i = 0; i < size; ++i) {
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
}
}
@@ -547,7 +577,7 @@ struct ConvertImpl<DataTypeString, ToDataType, Name> {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t /*input_rows_count*/,
+ size_t /*input_rows_count*/, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
return Status::RuntimeError("not support convert from string");
}
@@ -832,19 +862,6 @@ public:
Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) override {
- return executeInternal(block, arguments, result, input_rows_count);
- }
-
- bool has_information_about_monotonicity() const override { return Monotonic::has(); }
-
- Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
- const Field& right) const override {
- return Monotonic::get(type, left, right);
- }
-
-private:
- Status executeInternal(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t input_rows_count) {
if (!arguments.size()) {
return Status::RuntimeError("Function {} expects at least 1 arguments", get_name());
}
@@ -873,13 +890,15 @@ private:
UInt32 scale = extract_to_decimal_scale(scale_column);
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
- block, arguments, result, input_rows_count, scale);
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), scale);
} else if constexpr (IsDataTypeDateTimeV2<RightDataType>) {
const ColumnWithTypeAndName& scale_column = block.get_by_position(result);
auto type =
check_and_get_data_type<DataTypeDateTimeV2>(scale_column.type.get());
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
- block, arguments, result, input_rows_count, type->get_scale());
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), type->get_scale());
} else {
ret_status = ConvertImpl<LeftDataType, RightDataType, Name>::execute(
block, arguments, result, input_rows_count);
@@ -896,6 +915,13 @@ private:
return ret_status;
}
}
+
+ bool has_information_about_monotonicity() const override { return Monotonic::has(); }
+
+ Monotonicity get_monotonicity_for_range(const IDataType& type, const Field& left,
+ const Field& right) const override {
+ return Monotonic::get(type, left, right);
+ }
};
using FunctionToUInt8 = FunctionConvert<DataTypeUInt8, NameToUInt8, ToNumberMonotonicity<UInt8>>;
@@ -1055,7 +1081,7 @@ struct ConvertThroughParsing {
template <typename Additions = void*>
static Status execute(Block& block, const ColumnNumbers& arguments, size_t result,
- size_t input_rows_count,
+ size_t input_rows_count, bool check_overflow [[maybe_unused]] = false,
Additions additions [[maybe_unused]] = Additions()) {
using ColVecTo = std::conditional_t<IsDecimalNumber<ToFieldType>,
ColumnDecimal<ToFieldType>, ColumnVector<ToFieldType>>;
@@ -1254,7 +1280,8 @@ public:
const ColumnNumbers& /*arguments*/,
size_t /*result*/) const override {
return std::make_shared<PreparedFunctionCast>(
- prepare_unpack_dictionaries(get_argument_types()[0], get_return_type()), name);
+ prepare_unpack_dictionaries(context, get_argument_types()[0], get_return_type()),
+ name);
}
String get_name() const override { return name; }
@@ -1347,7 +1374,8 @@ private:
using RightDataType = typename Types::RightType;
ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
- block, arguments, result, input_rows_count, scale);
+ block, arguments, result, input_rows_count,
+ context->impl()->check_overflow_for_decimal(), scale);
return true;
});
@@ -1396,7 +1424,7 @@ private:
return create_unsupport_wrapper(error_msg);
}
- WrapperType create_array_wrapper(const DataTypePtr& from_type_untyped,
+ WrapperType create_array_wrapper(FunctionContext* context, const DataTypePtr& from_type_untyped,
const DataTypeArray& to_type) const {
/// Conversion from String through parsing.
if (check_and_get_data_type<DataTypeString>(from_type_untyped.get())) {
@@ -1425,7 +1453,8 @@ private:
const DataTypePtr& to_nested_type = to_type.get_nested_type();
/// Prepare nested type conversion
- const auto nested_function = prepare_unpack_dictionaries(from_nested_type, to_nested_type);
+ const auto nested_function =
+ prepare_unpack_dictionaries(context, from_nested_type, to_nested_type);
return [nested_function, from_nested_type, to_nested_type](
FunctionContext* context, Block& block, const ColumnNumbers& arguments,
@@ -1513,7 +1542,7 @@ private:
}
}
- WrapperType prepare_unpack_dictionaries(const DataTypePtr& from_type,
+ WrapperType prepare_unpack_dictionaries(FunctionContext* context, const DataTypePtr& from_type,
const DataTypePtr& to_type) const {
const auto& from_nested = from_type;
const auto& to_nested = to_type;
@@ -1534,18 +1563,20 @@ private:
constexpr bool skip_not_null_check = false;
- auto wrapper = prepare_remove_nullable(from_nested, to_nested, skip_not_null_check);
+ auto wrapper =
+ prepare_remove_nullable(context, from_nested, to_nested, skip_not_null_check);
return wrapper;
}
- WrapperType prepare_remove_nullable(const DataTypePtr& from_type, const DataTypePtr& to_type,
+ WrapperType prepare_remove_nullable(FunctionContext* context, const DataTypePtr& from_type,
+ const DataTypePtr& to_type,
bool skip_not_null_check) const {
/// Determine whether pre-processing and/or post-processing must take place during conversion.
bool source_is_nullable = from_type->is_nullable();
bool result_is_nullable = to_type->is_nullable();
- auto wrapper = prepare_impl(remove_nullable(from_type), remove_nullable(to_type),
+ auto wrapper = prepare_impl(context, remove_nullable(from_type), remove_nullable(to_type),
result_is_nullable);
if (result_is_nullable) {
@@ -1620,8 +1651,8 @@ private:
/// 'from_type' and 'to_type' are nested types in case of Nullable.
/// 'requested_result_is_nullable' is true if CAST to Nullable type is requested.
- WrapperType prepare_impl(const DataTypePtr& from_type, const DataTypePtr& to_type,
- bool requested_result_is_nullable) const {
+ WrapperType prepare_impl(FunctionContext* context, const DataTypePtr& from_type,
+ const DataTypePtr& to_type, bool requested_result_is_nullable) const {
if (from_type->equals(*to_type))
return create_identity_wrapper(from_type);
else if (WhichDataType(from_type).is_nothing())
@@ -1679,7 +1710,8 @@ private:
case TypeIndex::String:
return create_string_wrapper(from_type);
case TypeIndex::Array:
- return create_array_wrapper(from_type, static_cast<const DataTypeArray&>(*to_type));
+ return create_array_wrapper(context, from_type,
+ static_cast<const DataTypeArray&>(*to_type));
default:
break;
}
diff --git a/be/src/vec/functions/function_date_or_datetime_computation.cpp b/be/src/vec/functions/function_date_or_datetime_computation.cpp
index 02dd86b58e..abaef68c4c 100644
--- a/be/src/vec/functions/function_date_or_datetime_computation.cpp
+++ b/be/src/vec/functions/function_date_or_datetime_computation.cpp
@@ -21,63 +21,48 @@
namespace doris::vectorized {
-using FunctionAddSeconds = FunctionDateOrDateTimeComputation<
- AddSecondsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddMinutes = FunctionDateOrDateTimeComputation<
- AddMinutesImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddHours =
- FunctionDateOrDateTimeComputation<AddHoursImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddDays =
- FunctionDateOrDateTimeComputation<AddDaysImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddWeeks =
- FunctionDateOrDateTimeComputation<AddWeeksImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddMonths =
- FunctionDateOrDateTimeComputation<AddMonthsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddQuarters = FunctionDateOrDateTimeComputation<
- AddQuartersImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionAddYears =
- FunctionDateOrDateTimeComputation<AddYearsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-
-using FunctionSubSeconds = FunctionDateOrDateTimeComputation<
- SubtractSecondsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubMinutes = FunctionDateOrDateTimeComputation<
- SubtractMinutesImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubHours = FunctionDateOrDateTimeComputation<
- SubtractHoursImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubDays = FunctionDateOrDateTimeComputation<
- SubtractDaysImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubWeeks = FunctionDateOrDateTimeComputation<
- SubtractWeeksImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubMonths = FunctionDateOrDateTimeComputation<
- SubtractMonthsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubQuarters = FunctionDateOrDateTimeComputation<
- SubtractQuartersImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-using FunctionSubYears = FunctionDateOrDateTimeComputation<
- SubtractYearsImpl<DataTypeDateTime, Int64, DataTypeDateTime>>;
-
-using FunctionDateDiff = FunctionDateOrDateTimeComputation<DateDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionTimeDiff = FunctionDateOrDateTimeComputation<TimeDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionYearsDiff = FunctionDateOrDateTimeComputation<YearsDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionMonthsDiff = FunctionDateOrDateTimeComputation<MonthsDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionDaysDiff = FunctionDateOrDateTimeComputation<DaysDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionWeeksDiff = FunctionDateOrDateTimeComputation<WeeksDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionHoursDiff = FunctionDateOrDateTimeComputation<HoursDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionMinutesDiff = FunctionDateOrDateTimeComputation<MintueSDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-using FunctionSecondsDiff = FunctionDateOrDateTimeComputation<SecondsDiffImpl<
- VecDateTimeValue, VecDateTimeValue, DataTypeDateTime, DataTypeDateTime, Int64, Int64>>;
-
-using FunctionToYearWeekTwoArgs = FunctionDateOrDateTimeComputation<
- ToYearWeekTwoArgsImpl<VecDateTimeValue, DataTypeDateTime, Int64>>;
-using FunctionToWeekTwoArgs = FunctionDateOrDateTimeComputation<
- ToWeekTwoArgsImpl<VecDateTimeValue, DataTypeDateTime, Int64>>;
+// DataTypeDateTime instantiations of the date/time computation functions.
+// Each *Impl template takes only the date data type(s) as template arguments.
+
+// Interval addition.
+using FunctionAddSeconds = FunctionDateOrDateTimeComputation<AddSecondsImpl<DataTypeDateTime>>;
+using FunctionAddMinutes = FunctionDateOrDateTimeComputation<AddMinutesImpl<DataTypeDateTime>>;
+using FunctionAddHours = FunctionDateOrDateTimeComputation<AddHoursImpl<DataTypeDateTime>>;
+using FunctionAddDays = FunctionDateOrDateTimeComputation<AddDaysImpl<DataTypeDateTime>>;
+using FunctionAddWeeks = FunctionDateOrDateTimeComputation<AddWeeksImpl<DataTypeDateTime>>;
+using FunctionAddMonths = FunctionDateOrDateTimeComputation<AddMonthsImpl<DataTypeDateTime>>;
+using FunctionAddQuarters = FunctionDateOrDateTimeComputation<AddQuartersImpl<DataTypeDateTime>>;
+using FunctionAddYears = FunctionDateOrDateTimeComputation<AddYearsImpl<DataTypeDateTime>>;
+
+// Interval subtraction.
+using FunctionSubSeconds = FunctionDateOrDateTimeComputation<SubtractSecondsImpl<DataTypeDateTime>>;
+using FunctionSubMinutes = FunctionDateOrDateTimeComputation<SubtractMinutesImpl<DataTypeDateTime>>;
+using FunctionSubHours = FunctionDateOrDateTimeComputation<SubtractHoursImpl<DataTypeDateTime>>;
+using FunctionSubDays = FunctionDateOrDateTimeComputation<SubtractDaysImpl<DataTypeDateTime>>;
+using FunctionSubWeeks = FunctionDateOrDateTimeComputation<SubtractWeeksImpl<DataTypeDateTime>>;
+using FunctionSubMonths = FunctionDateOrDateTimeComputation<SubtractMonthsImpl<DataTypeDateTime>>;
+using FunctionSubQuarters =
+ FunctionDateOrDateTimeComputation<SubtractQuartersImpl<DataTypeDateTime>>;
+using FunctionSubYears = FunctionDateOrDateTimeComputation<SubtractYearsImpl<DataTypeDateTime>>;
+
+// Differences between two DataTypeDateTime arguments.
+using FunctionDateDiff =
+ FunctionDateOrDateTimeComputation<DateDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionTimeDiff =
+ FunctionDateOrDateTimeComputation<TimeDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionYearsDiff =
+ FunctionDateOrDateTimeComputation<YearsDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionMonthsDiff =
+ FunctionDateOrDateTimeComputation<MonthsDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionDaysDiff =
+ FunctionDateOrDateTimeComputation<DaysDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionWeeksDiff =
+ FunctionDateOrDateTimeComputation<WeeksDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionHoursDiff =
+ FunctionDateOrDateTimeComputation<HoursDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+// NOTE(review): "MintueSDiffImpl" is the spelling used by the declaration in the header.
+using FunctionMinutesDiff =
+ FunctionDateOrDateTimeComputation<MintueSDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+using FunctionSecondsDiff =
+ FunctionDateOrDateTimeComputation<SecondsDiffImpl<DataTypeDateTime, DataTypeDateTime>>;
+
+// Two-argument week functions (date, mode).
+using FunctionToYearWeekTwoArgs =
+ FunctionDateOrDateTimeComputation<ToYearWeekTwoArgsImpl<DataTypeDateTime>>;
+using FunctionToWeekTwoArgs =
+ FunctionDateOrDateTimeComputation<ToWeekTwoArgsImpl<DataTypeDateTime>>;
struct NowFunctionName {
static constexpr auto name = "now";
diff --git a/be/src/vec/functions/function_date_or_datetime_computation.h b/be/src/vec/functions/function_date_or_datetime_computation.h
index a22bc1a357..eba6d17fc7 100644
--- a/be/src/vec/functions/function_date_or_datetime_computation.h
+++ b/be/src/vec/functions/function_date_or_datetime_computation.h
@@ -27,13 +27,14 @@
#include "vec/data_types/data_type_date.h"
#include "vec/data_types/data_type_date_time.h"
#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_time.h"
#include "vec/functions/function.h"
#include "vec/functions/function_helpers.h"
#include "vec/runtime/vdatetime_value.h"
namespace doris::vectorized {
-template <TimeUnit unit, typename Arg, typename DateValueType, typename ResultDateValueType,
- typename ResultType>
+template <TimeUnit unit, typename DateValueType, typename ResultDateValueType, typename ResultType,
+ typename Arg>
extern ResultType date_time_add(const Arg& t, Int64 delta, bool& is_null) {
auto ts_value = binary_cast<Arg, DateValueType>(t);
TimeInterval interval(unit, delta, false);
@@ -51,9 +52,20 @@ extern ResultType date_time_add(const Arg& t, Int64 delta, bool& is_null) {
}
#define ADD_TIME_FUNCTION_IMPL(CLASS, NAME, UNIT) \
- template <typename DateType, typename ArgType, typename ResultType> \
+ template <typename DateType> \
struct CLASS { \
- using ReturnType = ResultType; \
+ using ReturnType = std::conditional_t< \
+ std::is_same_v<DateType, DataTypeDate> || \
+ std::is_same_v<DateType, DataTypeDateTime>, \
+ DataTypeDateTime, \
+ std::conditional_t< \
+ std::is_same_v<DateType, DataTypeDateV2>, \
+ std::conditional_t<TimeUnit::UNIT == TimeUnit::HOUR || \
+ TimeUnit::UNIT == TimeUnit::MINUTE || \
+ TimeUnit::UNIT == TimeUnit::SECOND || \
+ TimeUnit::UNIT == TimeUnit::SECOND_MICROSECOND, \
+ DataTypeDateTimeV2, DataTypeDateV2>, \
+ DataTypeDateTimeV2>>; \
using ReturnNativeType = std::conditional_t< \
std::is_same_v<DateType, DataTypeDate> || \
std::is_same_v<DateType, DataTypeDateTime>, \
@@ -66,12 +78,18 @@ extern ResultType date_time_add(const Arg& t, Int64 delta, bool& is_null) {
TimeUnit::UNIT == TimeUnit::SECOND_MICROSECOND, \
UInt64, UInt32>, \
UInt64>>; \
+ using InputNativeType = std::conditional_t< \
+ std::is_same_v<DateType, DataTypeDate> || \
+ std::is_same_v<DateType, DataTypeDateTime>, \
+ Int64, \
+ std::conditional_t<std::is_same_v<DateType, DataTypeDateV2>, UInt32, UInt64>>; \
static constexpr auto name = #NAME; \
static constexpr auto is_nullable = true; \
- static inline ReturnNativeType execute(const ArgType& t, Int64 delta, bool& is_null) { \
+ static inline ReturnNativeType execute(const InputNativeType& t, Int64 delta, \
+ bool& is_null) { \
if constexpr (std::is_same_v<DateType, DataTypeDate> || \
std::is_same_v<DateType, DataTypeDateTime>) { \
- return date_time_add<TimeUnit::UNIT, ArgType, doris::vectorized::VecDateTimeValue, \
+ return date_time_add<TimeUnit::UNIT, doris::vectorized::VecDateTimeValue, \
doris::vectorized::VecDateTimeValue, ReturnNativeType>( \
t, delta, is_null); \
} else if constexpr (std::is_same_v<DateType, DataTypeDateV2>) { \
@@ -79,17 +97,17 @@ extern ResultType date_time_add(const Arg& t, Int64 delta, bool& is_null) {
TimeUnit::UNIT == TimeUnit::MINUTE || \
TimeUnit::UNIT == TimeUnit::SECOND || \
TimeUnit::UNIT == TimeUnit::SECOND_MICROSECOND) { \
- return date_time_add<TimeUnit::UNIT, ArgType, DateV2Value<DateV2ValueType>, \
+ return date_time_add<TimeUnit::UNIT, DateV2Value<DateV2ValueType>, \
DateV2Value<DateTimeV2ValueType>, ReturnNativeType>( \
t, delta, is_null); \
} else { \
- return date_time_add<TimeUnit::UNIT, ArgType, DateV2Value<DateV2ValueType>, \
+ return date_time_add<TimeUnit::UNIT, DateV2Value<DateV2ValueType>, \
DateV2Value<DateV2ValueType>, ReturnNativeType>(t, delta, \
is_null); \
} \
\
} else { \
- return date_time_add<TimeUnit::UNIT, ArgType, DateV2Value<DateTimeV2ValueType>, \
+ return date_time_add<TimeUnit::UNIT, DateV2Value<DateTimeV2ValueType>, \
DateV2Value<DateTimeV2ValueType>, ReturnNativeType>(t, delta, \
is_null); \
} \
@@ -108,25 +126,33 @@ ADD_TIME_FUNCTION_IMPL(AddWeeksImpl, weeks_add, WEEK);
ADD_TIME_FUNCTION_IMPL(AddMonthsImpl, months_add, MONTH);
ADD_TIME_FUNCTION_IMPL(AddYearsImpl, years_add, YEAR);
-template <typename DateType, typename ArgType, typename ResultType>
+template <typename DateType>
struct AddQuartersImpl {
- using ReturnType = ResultType;
+ using ReturnType =
+ std::conditional_t<std::is_same_v<DateType, DataTypeDate> ||
+ std::is_same_v<DateType, DataTypeDateTime>,
+ DataTypeDateTime,
+ std::conditional_t<std::is_same_v<DateType, DataTypeDateV2>,
+ DataTypeDateV2, DataTypeDateTimeV2>>;
+ using InputNativeType = std::conditional_t<
+ std::is_same_v<DateType, DataTypeDate> || std::is_same_v<DateType, DataTypeDateTime>,
+ Int64, std::conditional_t<std::is_same_v<DateType, DataTypeDateV2>, UInt32, UInt64>>;
using ReturnNativeType = std::conditional_t<
std::is_same_v<DateType, DataTypeDate> || std::is_same_v<DateType, DataTypeDateTime>,
Int64, std::conditional_t<std::is_same_v<DateType, DataTypeDateV2>, UInt32, UInt64>>;
static constexpr auto name = "quarters_add";
static constexpr auto is_nullable = true;
- static inline ReturnNativeType execute(const ArgType& t, Int64 delta, bool& is_null) {
+ static inline ReturnNativeType execute(const InputNativeType& t, Int64 delta, bool& is_null) {
if constexpr (std::is_same_v<DateType, DataTypeDate> ||
std::is_same_v<DateType, DataTypeDateTime>) {
- return date_time_add<TimeUnit::MONTH, ArgType, doris::vectorized::VecDateTimeValue,
+ return date_time_add<TimeUnit::MONTH, doris::vectorized::VecDateTimeValue,
doris::vectorized::VecDateTimeValue, ReturnNativeType>(t, delta,
is_null);
} else if constexpr (std::is_same_v<DateType, DataTypeDateV2>) {
- return date_time_add<TimeUnit::MONTH, ArgType, DateV2Value<DateV2ValueType>,
+ return date_time_add<TimeUnit::MONTH, DateV2Value<DateV2ValueType>,
DateV2Value<DateV2ValueType>, ReturnNativeType>(t, delta, is_null);
} else {
- return date_time_add<TimeUnit::MONTH, ArgType, DateV2Value<DateTimeV2ValueType>,
+ return date_time_add<TimeUnit::MONTH, DateV2Value<DateTimeV2ValueType>,
DateV2Value<DateTimeV2ValueType>, ReturnNativeType>(t, delta,
is_null);
}
@@ -135,11 +161,12 @@ struct AddQuartersImpl {
static DataTypes get_variadic_argument_types() { return {std::make_shared<DateType>()}; }
};
-template <typename Transform, typename DateType, typename ArgType, typename ResultType>
+template <typename Transform, typename DateType>
struct SubtractIntervalImpl {
- using ReturnType = ResultType;
+ using ReturnType = typename Transform::ReturnType;
+ using InputNativeType = typename Transform::InputNativeType;
static constexpr auto is_nullable = true;
- static inline Int64 execute(const ArgType& t, Int64 delta, bool& is_null) {
+ static inline Int64 execute(const InputNativeType& t, Int64 delta, bool& is_null) {
return Transform::execute(t, -delta, is_null);
}
@@ -148,108 +175,81 @@ struct SubtractIntervalImpl {
}
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractSecondsImpl : SubtractIntervalImpl<AddSecondsImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractSecondsImpl : SubtractIntervalImpl<AddSecondsImpl<DateType>, DateType> {
static constexpr auto name = "seconds_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractMinutesImpl : SubtractIntervalImpl<AddMinutesImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractMinutesImpl : SubtractIntervalImpl<AddMinutesImpl<DateType>, DateType> {
static constexpr auto name = "minutes_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractHoursImpl : SubtractIntervalImpl<AddHoursImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractHoursImpl : SubtractIntervalImpl<AddHoursImpl<DateType>, DateType> {
static constexpr auto name = "hours_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractDaysImpl : SubtractIntervalImpl<AddDaysImpl<DateType, ArgType, ResultType>, DateType,
- ArgType, ResultType> {
+template <typename DateType>
+struct SubtractDaysImpl : SubtractIntervalImpl<AddDaysImpl<DateType>, DateType> {
static constexpr auto name = "days_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractWeeksImpl : SubtractIntervalImpl<AddWeeksImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractWeeksImpl : SubtractIntervalImpl<AddWeeksImpl<DateType>, DateType> {
static constexpr auto name = "weeks_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractMonthsImpl : SubtractIntervalImpl<AddMonthsImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractMonthsImpl : SubtractIntervalImpl<AddMonthsImpl<DateType>, DateType> {
static constexpr auto name = "months_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractQuartersImpl : SubtractIntervalImpl<AddQuartersImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractQuartersImpl : SubtractIntervalImpl<AddQuartersImpl<DateType>, DateType> {
static constexpr auto name = "quarters_sub";
};
-template <typename DateType, typename ArgType, typename ResultType>
-struct SubtractYearsImpl : SubtractIntervalImpl<AddYearsImpl<DateType, ArgType, ResultType>,
- DateType, ArgType, ResultType> {
+template <typename DateType>
+struct SubtractYearsImpl : SubtractIntervalImpl<AddYearsImpl<DateType>, DateType> {
static constexpr auto name = "years_sub";
};
-template <typename DateValueType1, typename DateValueType2, typename DateType1, typename DateType2,
- typename ArgType1, typename ArgType2>
-struct DateDiffImpl {
- using ReturnType = DataTypeInt32;
- static constexpr auto name = "datediff";
- static constexpr auto is_nullable = false;
- static inline Int32 execute(const ArgType1& t0, const ArgType2& t1, bool& is_null) {
- const auto& ts0 = reinterpret_cast<const DateValueType1&>(t0);
- const auto& ts1 = reinterpret_cast<const DateValueType2&>(t1);
- is_null = !ts0.is_valid_date() || !ts1.is_valid_date();
- return ts0.daynr() - ts1.daynr();
- }
-
- static DataTypes get_variadic_argument_types() {
- return {std::make_shared<DateType1>(), std::make_shared<DateType2>()};
- }
-};
-
-template <typename DateValueType1, typename DateValueType2, typename DateType1, typename DateType2,
- typename ArgType1, typename ArgType2>
-struct TimeDiffImpl {
- using ReturnType = DataTypeFloat64;
- static constexpr auto name = "timediff";
- static constexpr auto is_nullable = false;
- static inline double execute(const ArgType1& t0, const ArgType2& t1, bool& is_null) {
- const auto& ts0 = reinterpret_cast<const DateValueType1&>(t0);
- const auto& ts1 = reinterpret_cast<const DateValueType2&>(t1);
- is_null = !ts0.is_valid_date() || !ts1.is_valid_date();
- return ts0.second_diff(ts1);
- }
-
- static DataTypes get_variadic_argument_types() {
- return {std::make_shared<DateType1>(), std::make_shared<DateType2>()};
- }
-};
+// Declares a struct template NAME<DateType1, DateType2> implementing a two-date function.
+//   NAME        - struct name to declare.
+//   FN_NAME     - stringized into the struct's `name` member.
+//   RETURN_TYPE - data type of the result column, exposed as ReturnType.
+//   STMT        - expression producing the result; it may refer to ts0 and ts1, the two
+//                 arguments reinterpreted as their date value types.
+// ArgType*/DateValueType* are derived from the date data types: DataTypeDateV2 maps to
+// UInt32 / DateV2Value<DateV2ValueType>, DataTypeDateTimeV2 to UInt64 /
+// DateV2Value<DateTimeV2ValueType>, and anything else to Int64 / VecDateTimeValue.
+// execute() sets is_null when either argument is not a valid date.
+#define DECLARE_DATE_FUNCTIONS(NAME, FN_NAME, RETURN_TYPE, STMT) \
+ template <typename DateType1, typename DateType2> \
+ struct NAME { \
+ using ArgType1 = std::conditional_t< \
+ std::is_same_v<DateType1, DataTypeDateV2>, UInt32, \
+ std::conditional_t<std::is_same_v<DateType1, DataTypeDateTimeV2>, UInt64, Int64>>; \
+ using ArgType2 = std::conditional_t< \
+ std::is_same_v<DateType2, DataTypeDateV2>, UInt32, \
+ std::conditional_t<std::is_same_v<DateType2, DataTypeDateTimeV2>, UInt64, Int64>>; \
+ using DateValueType1 = std::conditional_t< \
+ std::is_same_v<DateType1, DataTypeDateV2>, DateV2Value<DateV2ValueType>, \
+ std::conditional_t<std::is_same_v<DateType1, DataTypeDateTimeV2>, \
+ DateV2Value<DateTimeV2ValueType>, VecDateTimeValue>>; \
+ using DateValueType2 = std::conditional_t< \
+ std::is_same_v<DateType2, DataTypeDateV2>, DateV2Value<DateV2ValueType>, \
+ std::conditional_t<std::is_same_v<DateType2, DataTypeDateTimeV2>, \
+ DateV2Value<DateTimeV2ValueType>, VecDateTimeValue>>; \
+ using ReturnType = RETURN_TYPE; \
+ static constexpr auto name = #FN_NAME; \
+ static constexpr auto is_nullable = false; \
+ /* Return the native field type of RETURN_TYPE: a fixed Int32 return type truncated */ \
+ /* the wider results of the DataTypeInt64 and DataTypeTime instantiations. */ \
+ /* NOTE(review): assumes every RETURN_TYPE exposes a FieldType alias - confirm. */ \
+ static inline typename ReturnType::FieldType execute(const ArgType1& t0, const ArgType2& t1, bool& is_null) { \
+ const auto& ts0 = reinterpret_cast<const DateValueType1&>(t0); \
+ const auto& ts1 = reinterpret_cast<const DateValueType2&>(t1); \
+ is_null = !ts0.is_valid_date() || !ts1.is_valid_date(); \
+ return STMT; \
+ } \
+ static DataTypes get_variadic_argument_types() { \
+ return {std::make_shared<DateType1>(), std::make_shared<DateType2>()}; \
+ } \
+ };
+DECLARE_DATE_FUNCTIONS(DateDiffImpl, datediff, DataTypeInt32, (ts0.daynr() - ts1.daynr()));
+DECLARE_DATE_FUNCTIONS(TimeDiffImpl, timediff, DataTypeTime, ts0.second_diff(ts1));
-#define TIME_DIFF_FUNCTION_IMPL(CLASS, NAME, UNIT) \
- template <typename DateValueType1, typename DateValueType2, typename DateType1, \
- typename DateType2, typename ArgType1, typename ArgType2> \
- struct CLASS { \
- using ReturnType = DataTypeInt64; \
- static constexpr auto name = #NAME; \
- static constexpr auto is_nullable = false; \
- static inline Int64 execute(const ArgType1& t0, const ArgType2& t1, bool& is_null) { \
- const auto& ts0 = reinterpret_cast<const DateValueType1&>(t0); \
- const auto& ts1 = reinterpret_cast<const DateValueType2&>(t1); \
- is_null = !ts0.is_valid_date() || !ts1.is_valid_date(); \
- return datetime_diff<TimeUnit::UNIT>(ts1, ts0); \
- } \
- \
- static DataTypes get_variadic_argument_types() { \
- return {std::make_shared<DateType1>(), std::make_shared<DateType2>()}; \
- } \
- }
+// Declares CLASS as a two-date difference function named NAME whose ReturnType is
+// DataTypeInt64 and whose value is datetime_diff<TimeUnit::UNIT>(ts1, ts0).
+// Note the argument order: the second date (ts1) is passed first.
+#define TIME_DIFF_FUNCTION_IMPL(CLASS, NAME, UNIT) \
+ DECLARE_DATE_FUNCTIONS(CLASS, NAME, DataTypeInt64, datetime_diff<TimeUnit::UNIT>(ts1, ts0))
TIME_DIFF_FUNCTION_IMPL(YearsDiffImpl, years_diff, YEAR);
TIME_DIFF_FUNCTION_IMPL(MonthsDiffImpl, months_diff, MONTH);
@@ -259,20 +259,27 @@ TIME_DIFF_FUNCTION_IMPL(HoursDiffImpl, hours_diff, HOUR);
TIME_DIFF_FUNCTION_IMPL(MintueSDiffImpl, minutes_diff, MINUTE);
TIME_DIFF_FUNCTION_IMPL(SecondsDiffImpl, seconds_diff, SECOND);
-#define TIME_FUNCTION_TWO_ARGS_IMPL(CLASS, NAME, FUNCTION) \
- template <typename DateValueType, typename DateType, typename ArgType> \
- struct CLASS { \
- using ReturnType = DataTypeInt32; \
- static constexpr auto name = #NAME; \
- static constexpr auto is_nullable = false; \
- static inline int64_t execute(const ArgType& t0, const Int32 mode, bool& is_null) { \
- const auto& ts0 = reinterpret_cast<const DateValueType&>(t0); \
- is_null = !ts0.is_valid_date(); \
- return ts0.FUNCTION; \
- } \
- static DataTypes get_variadic_argument_types() { \
- return {std::make_shared<DateType>(), std::make_shared<DataTypeInt32>()}; \
- } \
// Generates struct CLASS<DateType>, the implementation of the two-argument function NAME that
// takes a date-like value and an Int32 `mode`.
//  - ArgType / DateValueType are derived from DateType:
//      DataTypeDateV2     -> UInt32 / DateV2Value<DateV2ValueType>
//      DataTypeDateTimeV2 -> UInt64 / DateV2Value<DateTimeV2ValueType>
//      anything else      -> Int64  / VecDateTimeValue
//  - execute() reinterprets t0 as DateValueType, sets is_null to !is_valid_date(), and returns
//    ts0.FUNCTION. FUNCTION is spliced in as a member-call expression and may refer to `mode`.
//  - ReturnType is DataTypeInt32 although execute() itself is declared to return int64_t.
//  - get_variadic_argument_types() reports {DateType, DataTypeInt32}.
+#define TIME_FUNCTION_TWO_ARGS_IMPL(CLASS, NAME, FUNCTION) \
+ template <typename DateType> \
+ struct CLASS { \
+ using ArgType = std::conditional_t< \
+ std::is_same_v<DateType, DataTypeDateV2>, UInt32, \
+ std::conditional_t<std::is_same_v<DateType, DataTypeDateTimeV2>, UInt64, Int64>>; \
+ using DateValueType = std::conditional_t< \
+ std::is_same_v<DateType, DataTypeDateV2>, DateV2Value<DateV2ValueType>, \
+ std::conditional_t<std::is_same_v<DateType, DataTypeDateTimeV2>, \
+ DateV2Value<DateTimeV2ValueType>, VecDateTimeValue>>; \
+ using ReturnType = DataTypeInt32; \
+ static constexpr auto name = #NAME; \
+ static constexpr auto is_nullable = false; \
+ static inline int64_t execute(const ArgType& t0, const Int32 mode, bool& is_null) { \
+ const auto& ts0 = reinterpret_cast<const DateValueType&>(t0); \
+ is_null = !ts0.is_valid_date(); \
+ return ts0.FUNCTION; \
+ } \
+ static DataTypes get_variadic_argument_types() { \
+ return {std::make_shared<DateType>(), std::make_shared<DataTypeInt32>()}; \
+ } \
}
TIME_FUNCTION_TWO_ARGS_IMPL(ToYearWeekTwoArgsImpl, yearweek, year_week(mysql_week_mode(mode)));
@@ -792,7 +799,7 @@ struct CurrentDateImpl {
template <typename FunctionName>
struct CurrentTimeImpl {
- using ReturnType = DataTypeFloat64;
+ using ReturnType = DataTypeTime;
static constexpr auto name = FunctionName::name;
static Status execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) {
diff --git a/be/src/vec/functions/function_date_or_datetime_computation_v2.cpp b/be/src/vec/functions/function_date_or_datetime_computation_v2.cpp
index ced2b2d70c..23d7f296dc 100644
--- a/be/src/vec/functions/function_date_or_datetime_computation_v2.cpp
+++ b/be/src/vec/functions/function_date_or_datetime_computation_v2.cpp
@@ -20,102 +20,76 @@
namespace doris::vectorized {
-using FunctionAddSecondsV2 = FunctionDateOrDateTimeComputation<
- AddSecondsImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionAddMinutesV2 = FunctionDateOrDateTimeComputation<
- AddMinutesImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionAddHoursV2 =
- FunctionDateOrDateTimeComputation<AddHoursImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionAddDaysV2 =
- FunctionDateOrDateTimeComputation<AddDaysImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionAddWeeksV2 =
- FunctionDateOrDateTimeComputation<AddWeeksImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionAddMonthsV2 =
- FunctionDateOrDateTimeComputation<AddMonthsImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionAddQuartersV2 =
- FunctionDateOrDateTimeComputation<AddQuartersImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionAddYearsV2 =
- FunctionDateOrDateTimeComputation<AddYearsImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-
-using FunctionSubSecondsV2 = FunctionDateOrDateTimeComputation<
- SubtractSecondsImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionSubMinutesV2 = FunctionDateOrDateTimeComputation<
- SubtractMinutesImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionSubHoursV2 = FunctionDateOrDateTimeComputation<
- SubtractHoursImpl<DataTypeDateV2, UInt32, DataTypeDateTimeV2>>;
-using FunctionSubDaysV2 =
- FunctionDateOrDateTimeComputation<SubtractDaysImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionSubWeeksV2 = FunctionDateOrDateTimeComputation<
- SubtractWeeksImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionSubMonthsV2 = FunctionDateOrDateTimeComputation<
- SubtractMonthsImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionSubQuartersV2 = FunctionDateOrDateTimeComputation<
- SubtractQuartersImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-using FunctionSubYearsV2 = FunctionDateOrDateTimeComputation<
- SubtractYearsImpl<DataTypeDateV2, UInt32, DataTypeDateV2>>;
-
-using FunctionToYearWeekTwoArgsV2 = FunctionDateOrDateTimeComputation<
- ToYearWeekTwoArgsImpl<DateV2Value<DateV2ValueType>, DataTypeDateV2, UInt32>>;
-using FunctionToWeekTwoArgsV2 = FunctionDateOrDateTimeComputation<
- ToWeekTwoArgsImpl<DateV2Value<DateV2ValueType>, DataTypeDateV2, UInt32>>;
-
-using FunctionDatetimeV2AddSeconds = FunctionDateOrDateTimeComputation<
- AddSecondsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddMinutes = FunctionDateOrDateTimeComputation<
- AddMinutesImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddHours = FunctionDateOrDateTimeComputation<
- AddHoursImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddDays = FunctionDateOrDateTimeComputation<
- AddDaysImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddWeeks = FunctionDateOrDateTimeComputation<
- AddWeeksImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddMonths = FunctionDateOrDateTimeComputation<
- AddMonthsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddQuarters = FunctionDateOrDateTimeComputation<
- AddQuartersImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2AddYears = FunctionDateOrDateTimeComputation<
- AddYearsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-
-using FunctionDatetimeV2SubSeconds = FunctionDateOrDateTimeComputation<
- SubtractSecondsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubMinutes = FunctionDateOrDateTimeComputation<
- SubtractMinutesImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubHours = FunctionDateOrDateTimeComputation<
- SubtractHoursImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubDays = FunctionDateOrDateTimeComputation<
- SubtractDaysImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubWeeks = FunctionDateOrDateTimeComputation<
- SubtractWeeksImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubMonths = FunctionDateOrDateTimeComputation<
- SubtractMonthsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubQuarters = FunctionDateOrDateTimeComputation<
- SubtractQuartersImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-using FunctionDatetimeV2SubYears = FunctionDateOrDateTimeComputation<
- SubtractYearsImpl<DataTypeDateTimeV2, UInt64, DataTypeDateTimeV2>>;
-
-#define FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, TYPE1, TYPE2, ARG1, ARG2, DATE_VALUE1, \
- DATE_VALUE2) \
- using NAME##_##TYPE1##_##TYPE2 = FunctionDateOrDateTimeComputation< \
- IMPL<DATE_VALUE1, DATE_VALUE2, TYPE1, TYPE2, ARG1, ARG2>>;
-
-#define ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateTimeV2, UInt64, \
- UInt64, DateV2Value<DateTimeV2ValueType>, \
- DateV2Value<DateTimeV2ValueType>) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateV2, UInt64, UInt32, \
- DateV2Value<DateTimeV2ValueType>, DateV2Value<DateV2ValueType>) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateTimeV2, UInt32, UInt64, \
- DateV2Value<DateV2ValueType>, DateV2Value<DateTimeV2ValueType>) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateTime, UInt64, Int64, \
- DateV2Value<DateTimeV2ValueType>, VecDateTimeValue) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTime, DataTypeDateTimeV2, Int64, UInt64, \
- VecDateTimeValue, DateV2Value<DateTimeV2ValueType>) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTime, DataTypeDateV2, Int64, UInt32, \
- VecDateTimeValue, DateV2Value<DateV2ValueType>) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateTime, UInt32, Int64, \
- DateV2Value<DateV2ValueType>, VecDateTimeValue) \
- FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateV2, UInt32, UInt32, \
- DateV2Value<DateV2ValueType>, DateV2Value<DateV2ValueType>)
// Interval-addition functions whose Impl is instantiated with DataTypeDateV2. The Impl templates
// now take only the date data type; the removed aliases above also spelled out the integer
// argument type and the result data type.
+using FunctionAddSecondsV2 = FunctionDateOrDateTimeComputation<AddSecondsImpl<DataTypeDateV2>>;
+using FunctionAddMinutesV2 = FunctionDateOrDateTimeComputation<AddMinutesImpl<DataTypeDateV2>>;
+using FunctionAddHoursV2 = FunctionDateOrDateTimeComputation<AddHoursImpl<DataTypeDateV2>>;
+using FunctionAddDaysV2 = FunctionDateOrDateTimeComputation<AddDaysImpl<DataTypeDateV2>>;
+using FunctionAddWeeksV2 = FunctionDateOrDateTimeComputation<AddWeeksImpl<DataTypeDateV2>>;
+using FunctionAddMonthsV2 = FunctionDateOrDateTimeComputation<AddMonthsImpl<DataTypeDateV2>>;
+using FunctionAddQuartersV2 = FunctionDateOrDateTimeComputation<AddQuartersImpl<DataTypeDateV2>>;
+using FunctionAddYearsV2 = FunctionDateOrDateTimeComputation<AddYearsImpl<DataTypeDateV2>>;
+
// Interval-subtraction functions whose Impl is instantiated with DataTypeDateV2.
+using FunctionSubSecondsV2 = FunctionDateOrDateTimeComputation<SubtractSecondsImpl<DataTypeDateV2>>;
+using FunctionSubMinutesV2 = FunctionDateOrDateTimeComputation<SubtractMinutesImpl<DataTypeDateV2>>;
+using FunctionSubHoursV2 = FunctionDateOrDateTimeComputation<SubtractHoursImpl<DataTypeDateV2>>;
+using FunctionSubDaysV2 = FunctionDateOrDateTimeComputation<SubtractDaysImpl<DataTypeDateV2>>;
+using FunctionSubWeeksV2 = FunctionDateOrDateTimeComputation<SubtractWeeksImpl<DataTypeDateV2>>;
+using FunctionSubMonthsV2 = FunctionDateOrDateTimeComputation<SubtractMonthsImpl<DataTypeDateV2>>;
+using FunctionSubQuartersV2 =
+ FunctionDateOrDateTimeComputation<SubtractQuartersImpl<DataTypeDateV2>>;
+using FunctionSubYearsV2 = FunctionDateOrDateTimeComputation<SubtractYearsImpl<DataTypeDateV2>>;
+
// Two-argument (value, mode) week functions whose Impl is instantiated with DataTypeDateV2.
+using FunctionToYearWeekTwoArgsV2 =
+ FunctionDateOrDateTimeComputation<ToYearWeekTwoArgsImpl<DataTypeDateV2>>;
+using FunctionToWeekTwoArgsV2 =
+ FunctionDateOrDateTimeComputation<ToWeekTwoArgsImpl<DataTypeDateV2>>;
+
// Interval-addition functions whose Impl is instantiated with DataTypeDateTimeV2.
+using FunctionDatetimeV2AddSeconds =
+ FunctionDateOrDateTimeComputation<AddSecondsImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddMinutes =
+ FunctionDateOrDateTimeComputation<AddMinutesImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddHours =
+ FunctionDateOrDateTimeComputation<AddHoursImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddDays =
+ FunctionDateOrDateTimeComputation<AddDaysImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddWeeks =
+ FunctionDateOrDateTimeComputation<AddWeeksImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddMonths =
+ FunctionDateOrDateTimeComputation<AddMonthsImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddQuarters =
+ FunctionDateOrDateTimeComputation<AddQuartersImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2AddYears =
+ FunctionDateOrDateTimeComputation<AddYearsImpl<DataTypeDateTimeV2>>;
+
// Interval-subtraction functions whose Impl is instantiated with DataTypeDateTimeV2.
+using FunctionDatetimeV2SubSeconds =
+ FunctionDateOrDateTimeComputation<SubtractSecondsImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubMinutes =
+ FunctionDateOrDateTimeComputation<SubtractMinutesImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubHours =
+ FunctionDateOrDateTimeComputation<SubtractHoursImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubDays =
+ FunctionDateOrDateTimeComputation<SubtractDaysImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubWeeks =
+ FunctionDateOrDateTimeComputation<SubtractWeeksImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubMonths =
+ FunctionDateOrDateTimeComputation<SubtractMonthsImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubQuarters =
+ FunctionDateOrDateTimeComputation<SubtractQuartersImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2SubYears =
+ FunctionDateOrDateTimeComputation<SubtractYearsImpl<DataTypeDateTimeV2>>;
+
// Declares the alias NAME_TYPE1_TYPE2 (token-pasted) for
// FunctionDateOrDateTimeComputation<IMPL<TYPE1, TYPE2>>.
+#define FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, TYPE1, TYPE2) \
+ using NAME##_##TYPE1##_##TYPE2 = FunctionDateOrDateTimeComputation<IMPL<TYPE1, TYPE2>>;
+
// Expands FUNCTION_DATEV2_WITH_TWO_ARGS once per ordered pair of argument types drawn from
// {DataTypeDateTimeV2, DataTypeDateV2, DataTypeDateTime}, leaving out the
// (DataTypeDateTime, DataTypeDateTime) pair -- eight aliases per IMPL.
+#define ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateTimeV2) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateV2) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateTimeV2) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTimeV2, DataTypeDateTime) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTime, DataTypeDateTimeV2) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateTime, DataTypeDateV2) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateTime) \
+ FUNCTION_DATEV2_WITH_TWO_ARGS(NAME, IMPL, DataTypeDateV2, DataTypeDateV2)
ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(FunctionDatetimeV2DateDiff, DateDiffImpl)
ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(FunctionDatetimeV2TimeDiff, TimeDiffImpl)
@@ -127,10 +101,10 @@ ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(FunctionDatetimeV2MinutesDiff, MintueSDiffImpl
ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(FunctionDatetimeV2SecondsDiff, SecondsDiffImpl)
ALL_FUNCTION_DATEV2_WITH_TWO_ARGS(FunctionDatetimeV2DaysDiff, DaysDiffImpl)
-using FunctionDatetimeV2ToYearWeekTwoArgs = FunctionDateOrDateTimeComputation<
- ToYearWeekTwoArgsImpl<DateV2Value<DateTimeV2ValueType>, DataTypeDateTimeV2, UInt64>>;
-using FunctionDatetimeV2ToWeekTwoArgs = FunctionDateOrDateTimeComputation<
- ToWeekTwoArgsImpl<DateV2Value<DateTimeV2ValueType>, DataTypeDateTimeV2, UInt64>>;
// Two-argument (value, mode) week functions whose Impl is instantiated with DataTypeDateTimeV2.
+using FunctionDatetimeV2ToYearWeekTwoArgs =
+ FunctionDateOrDateTimeComputation<ToYearWeekTwoArgsImpl<DataTypeDateTimeV2>>;
+using FunctionDatetimeV2ToWeekTwoArgs =
+ FunctionDateOrDateTimeComputation<ToWeekTwoArgsImpl<DataTypeDateTimeV2>>;
void register_function_date_time_computation_v2(SimpleFunctionFactory& factory) {
factory.register_function<FunctionAddSecondsV2>();
diff --git a/be/src/vec/functions/function_running_difference.h b/be/src/vec/functions/function_running_difference.h
index b9b53892f9..85f17ad786 100644
--- a/be/src/vec/functions/function_running_difference.h
+++ b/be/src/vec/functions/function_running_difference.h
@@ -31,6 +31,7 @@
#include "vec/data_types/data_type_date_time.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_time.h"
#include "vec/data_types/data_type_time_v2.h"
#include "vec/data_types/number_traits.h"
#include "vec/functions/function.h"
@@ -75,7 +76,7 @@ public:
} else if (which.is_decimal()) {
return_type = nested_type;
} else if (which.is_date_time() || which.is_date_time_v2()) {
- return_type = std::make_shared<DataTypeFloat64>();
+ return_type = std::make_shared<DataTypeTime>();
} else if (which.is_date() || which.is_date_v2()) {
return_type = std::make_shared<DataTypeInt32>();
}
diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h
index fa6f2f9934..b37670808e 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -1936,9 +1936,8 @@ public:
ColumnPtr argument_column = block.get_by_position(arguments[0]).column;
auto result_column = assert_cast<ColumnString*>(res_column.get());
- auto data_column = assert_cast<const typename Impl::ColumnType*>(argument_column.get());
- Impl::execute(context, result_column, data_column, input_rows_count);
+ Impl::execute(context, result_column, argument_column, input_rows_count);
block.replace_by_position(result, std::move(res_column));
return Status::OK();
@@ -1946,12 +1945,11 @@ public:
};
struct MoneyFormatDoubleImpl {
- using ColumnType = ColumnVector<Float64>;
-
static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeFloat64>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t input_rows_count) {
+ const ColumnPtr col_ptr, size_t input_rows_count) {
+ const auto* data_column = assert_cast<const ColumnVector<Float64>*>(col_ptr.get());
for (size_t i = 0; i < input_rows_count; i++) {
double value =
MathFunctions::my_double_round(data_column->get_element(i), 2, false, false);
@@ -1962,12 +1960,11 @@ struct MoneyFormatDoubleImpl {
};
struct MoneyFormatInt64Impl {
- using ColumnType = ColumnVector<Int64>;
-
static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt64>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t input_rows_count) {
+ const ColumnPtr col_ptr, size_t input_rows_count) {
+ const auto* data_column = assert_cast<const ColumnVector<Int64>*>(col_ptr.get());
for (size_t i = 0; i < input_rows_count; i++) {
Int64 value = data_column->get_element(i);
StringVal str = StringFunctions::do_money_format<Int64, 26>(context, value);
@@ -1977,12 +1974,11 @@ struct MoneyFormatInt64Impl {
};
struct MoneyFormatInt128Impl {
- using ColumnType = ColumnVector<Int128>;
-
static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt128>()}; }
static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t input_rows_count) {
+ const ColumnPtr col_ptr, size_t input_rows_count) {
+ const auto* data_column = assert_cast<const ColumnVector<Int128>*>(col_ptr.get());
for (size_t i = 0; i < input_rows_count; i++) {
Int128 value = data_column->get_element(i);
StringVal str = StringFunctions::do_money_format<Int128, 52>(context, value);
@@ -1992,24 +1988,81 @@ struct MoneyFormatInt128Impl {
};
// Money-format implementation for decimal columns. execute() dispatches on the concrete column
// type behind col_ptr:
//  - ColumnDecimal<Decimal128>: rounded to 2 fractional digits via DecimalV2Value (HALF_UP).
//  - ColumnDecimal<Decimal32 / Decimal64 / Decimal128I>: the fractional part is rescaled to
//    2 digits by hand using the column's scale.
// A column of any other type matches no branch, so nothing is appended to result_column.
struct MoneyFormatDecimalImpl {
- using ColumnType = ColumnDecimal<Decimal128>;
-
static DataTypes get_variadic_argument_types() {
return {std::make_shared<DataTypeDecimal<Decimal128>>(27, 9)};
}
- static void execute(FunctionContext* context, ColumnString* result_column,
- const ColumnType* data_column, size_t input_rows_count) {
- for (size_t i = 0; i < input_rows_count; i++) {
- DecimalV2Val value = DecimalV2Val(data_column->get_element(i));
+ static void execute(FunctionContext* context, ColumnString* result_column, ColumnPtr col_ptr,
+ size_t input_rows_count) {
+ if (auto* decimalv2_column = check_and_get_column<ColumnDecimal<Decimal128>>(*col_ptr)) {
+ for (size_t i = 0; i < input_rows_count; i++) {
+ DecimalV2Val value = DecimalV2Val(decimalv2_column->get_element(i));
- DecimalV2Value rounded(0);
- DecimalV2Value::from_decimal_val(value).round(&rounded, 2, HALF_UP);
+ DecimalV2Value rounded(0);
+ DecimalV2Value::from_decimal_val(value).round(&rounded, 2, HALF_UP);
- StringVal str = StringFunctions::do_money_format<int64_t, 26>(
- context, rounded.int_value(), abs(rounded.frac_value() / 10000000));
+ StringVal str = StringFunctions::do_money_format<int64_t, 26>(
+ context, rounded.int_value(), abs(rounded.frac_value() / 10000000));
- result_column->insert_data(reinterpret_cast<const char*>(str.ptr), str.len);
+ result_column->insert_data(reinterpret_cast<const char*>(str.ptr), str.len);
+ }
+ } else if (auto* decimal32_column =
+ check_and_get_column<ColumnDecimal<Decimal32>>(*col_ptr)) {
+ const UInt32 scale = decimal32_column->get_scale();
+ const auto multiplier =
+ scale > 2 ? common::exp10_i32(scale - 2) : common::exp10_i32(2 - scale);
+ for (size_t i = 0; i < input_rows_count; i++) {
+ Decimal32 frac_part = decimal32_column->get_fractional_part(i);
// scale > 2: drop the extra digits, adding 1 when twice the dropped remainder exceeds
// multiplier. scale < 2: pad the fractional part up to 2 digits. The same logic is repeated
// in the Decimal64 and Decimal128I branches below, so the notes here apply to all three.
// NOTE(review): the comparison is strict (>), so a remainder of exactly half is rounded
// down, unlike the HALF_UP rounding of the Decimal128 branch above -- confirm intended.
// NOTE(review): frac_part can reach 100 after rounding (e.g. 996 at scale 3, assuming
// exp10_i32(1) == 10) and no carry into the whole part is visible here -- confirm that
// do_money_format copes with it.
// NOTE(review): unlike the Decimal128 branch, frac_part is passed without abs() -- confirm
// the sign handling for negative values.
+ if (scale > 2) {
+ int delta = ((frac_part % multiplier) << 1) > multiplier;
+ frac_part = frac_part / multiplier + delta;
+ } else if (scale < 2) {
+ frac_part = frac_part * multiplier;
+ }
+
+ StringVal str = StringFunctions::do_money_format<int64_t, 26>(
+ context, decimal32_column->get_whole_part(i), frac_part);
+
+ result_column->insert_data(reinterpret_cast<const char*>(str.ptr), str.len);
+ }
+ } else if (auto* decimal64_column =
+ check_and_get_column<ColumnDecimal<Decimal64>>(*col_ptr)) {
// NOTE(review): multiplier is still built with common::exp10_i32 here and in the Decimal128I
// branch. If these columns allow scale - 2 > 9, the power of ten no longer fits in 32 bits
// (10^10 > INT32_MAX) -- confirm the valid scale range or switch to a wider helper.
+ const UInt32 scale = decimal64_column->get_scale();
+ const auto multiplier =
+ scale > 2 ? common::exp10_i32(scale - 2) : common::exp10_i32(2 - scale);
+ for (size_t i = 0; i < input_rows_count; i++) {
+ Decimal64 frac_part = decimal64_column->get_fractional_part(i);
+ if (scale > 2) {
+ int delta = ((frac_part % multiplier) << 1) > multiplier;
+ frac_part = frac_part / multiplier + delta;
+ } else if (scale < 2) {
+ frac_part = frac_part * multiplier;
+ }
+
+ StringVal str = StringFunctions::do_money_format<int64_t, 26>(
+ context, decimal64_column->get_whole_part(i), frac_part);
+
+ result_column->insert_data(reinterpret_cast<const char*>(str.ptr), str.len);
+ }
+ } else if (auto* decimal128_column =
+ check_and_get_column<ColumnDecimal<Decimal128I>>(*col_ptr)) {
+ const UInt32 scale = decimal128_column->get_scale();
+ const auto multiplier =
+ scale > 2 ? common::exp10_i32(scale - 2) : common::exp10_i32(2 - scale);
+ for (size_t i = 0; i < input_rows_count; i++) {
+ Decimal128I frac_part = decimal128_column->get_fractional_part(i);
+ if (scale > 2) {
+ int delta = ((frac_part % multiplier) << 1) > multiplier;
+ frac_part = frac_part / multiplier + delta;
+ } else if (scale < 2) {
+ frac_part = frac_part * multiplier;
+ }
+
+ StringVal str = StringFunctions::do_money_format<int64_t, 26>(
+ context, decimal128_column->get_whole_part(i), frac_part);
+
+ result_column->insert_data(reinterpret_cast<const char*>(str.ptr), str.len);
+ }
}
}
};
diff --git a/be/src/vec/runtime/vdatetime_value.cpp b/be/src/vec/runtime/vdatetime_value.cpp
index 8982391764..8111569c6b 100644
--- a/be/src/vec/runtime/vdatetime_value.cpp
+++ b/be/src/vec/runtime/vdatetime_value.cpp
@@ -1893,6 +1893,7 @@ bool DateV2Value<T>::from_date_str(const char* date_str, int len, int scale) {
if (field_idx == 2 && *ptr == 'T') {
// YYYYMMDDTHHMMDD, skip 'T' and continue
ptr++;
+ field_idx++;
continue;
}
diff --git a/be/src/vec/sink/vmysql_result_writer.cpp b/be/src/vec/sink/vmysql_result_writer.cpp
index 2900d2d486..d563e0a0a0 100644
--- a/be/src/vec/sink/vmysql_result_writer.cpp
+++ b/be/src/vec/sink/vmysql_result_writer.cpp
@@ -351,8 +351,7 @@ int VMysqlResultWriter::_add_one_cell(const ColumnPtr& column_ptr, size_t row_id
} else if (which.is_date_or_datetime()) {
auto& column_vector = assert_cast<const ColumnVector<Int64>&>(*column);
auto value = column_vector[row_idx].get<Int64>();
- VecDateTimeValue datetime;
- memcpy(static_cast<void*>(&datetime), static_cast<void*>(&value), sizeof(value));
+ VecDateTimeValue datetime = binary_cast<Int64, VecDateTimeValue>(value);
if (which.is_date()) {
datetime.cast_to_date();
}
@@ -362,11 +361,19 @@ int VMysqlResultWriter::_add_one_cell(const ColumnPtr& column_ptr, size_t row_id
} else if (which.is_date_v2()) {
auto& column_vector = assert_cast<const ColumnVector<UInt32>&>(*column);
auto value = column_vector[row_idx].get<UInt32>();
- DateV2Value<DateV2ValueType> datev2;
- memcpy(static_cast<void*>(&datev2), static_cast<void*>(&value), sizeof(value));
+ DateV2Value<DateV2ValueType> datev2 =
+ binary_cast<UInt32, DateV2Value<DateV2ValueType>>(value);
char buf[64];
char* pos = datev2.to_string(buf);
return buffer.push_string(buf, pos - buf - 1);
+ } else if (which.is_date_time_v2()) {
+ auto& column_vector = assert_cast<const ColumnVector<UInt64>&>(*column);
+ auto value = column_vector[row_idx].get<UInt64>();
+ DateV2Value<DateTimeV2ValueType> datetimev2 =
+ binary_cast<UInt64, DateV2Value<DateTimeV2ValueType>>(value);
+ char buf[64];
+ char* pos = datetimev2.to_string(buf);
+ return buffer.push_string(buf, pos - buf - 1);
} else if (which.is_decimal32()) {
DataTypePtr nested_type = type;
if (type->is_nullable()) {
diff --git a/be/test/vec/function/function_running_difference_test.cpp b/be/test/vec/function/function_running_difference_test.cpp
index 6b002097b7..0245db4523 100644
--- a/be/test/vec/function/function_running_difference_test.cpp
+++ b/be/test/vec/function/function_running_difference_test.cpp
@@ -55,7 +55,7 @@ TEST(FunctionRunningDifferenceTest, function_running_difference_test) {
{{std::string("2019-07-18 12:00:06")}, (double)1.0},
{{std::string("2019-07-18 12:00:08")}, (double)2.0},
{{std::string("2019-07-18 12:00:10")}, (double)2.0}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
InputTypeSet input_types = {TypeIndex::Date};
diff --git a/be/test/vec/function/function_test_util.h b/be/test/vec/function/function_test_util.h
index 2603d224a4..6083278d29 100644
--- a/be/test/vec/function/function_test_util.h
+++ b/be/test/vec/function/function_test_util.h
@@ -35,6 +35,7 @@
#include "vec/data_types/data_type_jsonb.h"
#include "vec/data_types/data_type_number.h"
#include "vec/data_types/data_type_string.h"
+#include "vec/data_types/data_type_time.h"
#include "vec/functions/simple_function_factory.h"
namespace doris::vectorized {
@@ -233,7 +234,8 @@ Status check_function(const std::string& func_name, const InputTypeSet& input_ty
fn_ctx_return.type = doris_udf::FunctionContext::TYPE_BOOLEAN;
} else if constexpr (std::is_same_v<ReturnType, DataTypeInt32>) {
fn_ctx_return.type = doris_udf::FunctionContext::TYPE_INT;
- } else if constexpr (std::is_same_v<ReturnType, DataTypeFloat64>) {
+ } else if constexpr (std::is_same_v<ReturnType, DataTypeFloat64> ||
+ std::is_same_v<ReturnType, DataTypeTime>) {
fn_ctx_return.type = doris_udf::FunctionContext::TYPE_DOUBLE;
} else if constexpr (std::is_same_v<ReturnType, DateTime>) {
fn_ctx_return.type = doris_udf::FunctionContext::TYPE_DATETIME;
@@ -293,7 +295,8 @@ Status check_function(const std::string& func_name, const InputTypeSet& input_ty
const auto& column_data = field.get<DecimalField<Decimal128>>().get_value();
EXPECT_EQ(expect_data.value, column_data.value) << " at row " << i;
} else if constexpr (std::is_same_v<ReturnType, DataTypeFloat32> ||
- std::is_same_v<ReturnType, DataTypeFloat64>) {
+ std::is_same_v<ReturnType, DataTypeFloat64> ||
+ std::is_same_v<ReturnType, DataTypeTime>) {
const auto& column_data = field.get<DataTypeFloat64::FieldType>();
EXPECT_DOUBLE_EQ(expect_data, column_data) << " at row " << i;
} else {
diff --git a/be/test/vec/function/function_time_test.cpp b/be/test/vec/function/function_time_test.cpp
index 18455a4566..8a68a608b8 100644
--- a/be/test/vec/function/function_time_test.cpp
+++ b/be/test/vec/function/function_time_test.cpp
@@ -186,7 +186,7 @@ TEST(VTimestampFunctionsTest, timediff_test) {
{{std::string("2019-00-18 12:00:00"), std::string("2019-07-18 13:01:02")}, Null()},
{{std::string("2019-07-18 12:00:00"), std::string("2019-07-00 13:01:02")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
TEST(VTimestampFunctionsTest, date_format_test) {
@@ -849,7 +849,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18"), std::string("2019-07-18")}, Null()},
{{std::string("2019-07-18"), std::string("2019-07-00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -860,7 +860,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18"), std::string("2019-07-18")}, Null()},
{{std::string("2019-07-18"), std::string("2019-07-00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -871,7 +871,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18"), std::string("2019-07-18")}, Null()},
{{std::string("2019-07-18"), std::string("2019-07-00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -883,7 +883,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18 00:00:00"), std::string("2019-07-18")}, Null()},
{{std::string("2019-07-18 00:00:00"), std::string("2019-07-00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -895,7 +895,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18"), std::string("2019-07-18 00:00:00")}, Null()},
{{std::string("2019-07-18"), std::string("2019-07-00 00:00:00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
InputTypeSet input_types = {TypeIndex::DateTimeV2, TypeIndex::DateTimeV2};
@@ -906,7 +906,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18 00:00:00"), std::string("2019-07-18 00:00:00")}, Null()},
{{std::string("2019-07-18 00:00:00"), std::string("2019-07-00 00:00:00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -918,7 +918,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18 00:00:00"), std::string("2019-07-18")}, Null()},
{{std::string("2019-07-18 00:00:00"), std::string("2019-07-00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -930,7 +930,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18"), std::string("2019-07-18 00:00:00")}, Null()},
{{std::string("2019-07-18"), std::string("2019-07-00 00:00:00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -942,7 +942,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-00-18 00:00:00"), std::string("2019-07-18 00:00:00")}, Null()},
{{std::string("2019-07-18 00:00:00"), std::string("2019-07-00 00:00:00")}, Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
{
@@ -957,7 +957,7 @@ TEST(VTimestampFunctionsTest, timediff_v2_test) {
{{std::string("2019-07-18 00:00:00.123"), std::string("2019-07-00 00:00:00")},
Null()}};
- check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
+ check_function<DataTypeTime, true>(func_name, input_types, data_set);
}
}
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
index c4364f2854..d643c24c1a 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/ScalarType.java
@@ -1042,7 +1042,7 @@ public class ScalarType extends Type {
if (t1.isDatetimeV2() && t2.isDatetimeV2()) {
return t1.scale > t2.scale ? t1 : t2;
}
- if ((t1.isDatetimeV2() || t1.isDateV2()) && (t1.isDatetimeV2() || t1.isDateV2())) {
+ if ((t1.isDatetimeV2() || t1.isDateV2()) && (t2.isDatetimeV2() || t2.isDateV2())) {
return t1.isDatetimeV2() ? t1 : t2;
}
if (strict) {
@@ -1065,6 +1065,12 @@ public class ScalarType extends Type {
targetPrecision, targetScale);
}
+ public static ScalarType getAssignmentCompatibleDecimalV3Type(ScalarType t1, ScalarType t2) {
+ int targetPrecision = Math.max(t1.decimalPrecision(), t2.decimalPrecision());
+ int targetScale = Math.max(t1.decimalScale(), t2.decimalScale());
+ return ScalarType.createDecimalV3Type(targetPrecision, targetScale);
+ }
+
/**
* Returns true t1 can be implicitly cast to t2, false otherwise.
* If strict is true, only consider casts that result in no loss of precision.
diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
index 184bafb313..f6b3f00742 100644
--- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
+++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java
@@ -393,6 +393,10 @@ public abstract class Type {
|| isScalarType(PrimitiveType.DATEV2) || isScalarType(PrimitiveType.DATETIMEV2);
}
+ public boolean isDateOrDateTime() {
+ return isScalarType(PrimitiveType.DATE) || isScalarType(PrimitiveType.DATETIME);
+ }
+
public boolean isDatetime() {
return isScalarType(PrimitiveType.DATETIME);
}
@@ -762,13 +766,17 @@ public abstract class Type {
type = ScalarType.createVarcharType(scalarType.getLen());
} else if (scalarType.getType() == TPrimitiveType.HLL) {
type = ScalarType.createHllType();
- } else if (scalarType.getType() == TPrimitiveType.DECIMALV2
- || scalarType.getType() == TPrimitiveType.DECIMAL32
+ } else if (scalarType.getType() == TPrimitiveType.DECIMALV2) {
+ Preconditions.checkState(scalarType.isSetPrecision()
+ && scalarType.isSetScale()); // getScale() is read below, so scale (not precision twice) must be set
+ type = ScalarType.createDecimalType(scalarType.getPrecision(),
+ scalarType.getScale());
+ } else if (scalarType.getType() == TPrimitiveType.DECIMAL32
|| scalarType.getType() == TPrimitiveType.DECIMAL64
|| scalarType.getType() == TPrimitiveType.DECIMAL128I) {
Preconditions.checkState(scalarType.isSetPrecision()
&& scalarType.isSetScale());
- type = ScalarType.createDecimalType(scalarType.getPrecision(),
+ type = ScalarType.createDecimalV3Type(scalarType.getPrecision(),
scalarType.getScale());
} else if (scalarType.getType() == TPrimitiveType.DATETIMEV2) {
Preconditions.checkState(scalarType.isSetPrecision()
@@ -1498,17 +1506,15 @@ public abstract class Type {
case DATE:
case DATEV2:
case DATETIME:
+ case DATETIMEV2:
case TIME:
+ case TIMEV2:
case CHAR:
case VARCHAR:
case HLL:
case BITMAP:
case QUANTILE_STATE:
return VARCHAR;
- case DATETIMEV2:
- return DEFAULT_DATETIMEV2;
- case TIMEV2:
- return DEFAULT_TIMEV2;
case DECIMALV2:
return DECIMALV2;
case DECIMAL32:
@@ -1660,6 +1666,7 @@ public abstract class Type {
case DATE:
case DATEV2:
case DATETIME:
+ case DATETIMEV2:
return Type.BIGINT;
case LARGEINT:
return Type.LARGEINT;
@@ -1671,8 +1678,6 @@ public abstract class Type {
case STRING:
case HLL:
return Type.DOUBLE;
- case DATETIMEV2:
- return Type.DEFAULT_DATETIMEV2;
case TIMEV2:
return Type.DEFAULT_TIMEV2;
case DECIMALV2:
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
index e5502fe635..02ea5579cb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java
@@ -21,6 +21,7 @@
package org.apache.doris.analysis;
import org.apache.doris.catalog.Function;
+import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
@@ -107,12 +108,13 @@ public class ArithmeticExpr extends Expr {
public static void initBuiltins(FunctionSet functionSet) {
for (Type t : Type.getNumericTypes()) {
+ NullableMode mode = t.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t));
+ Operator.MULTIPLY.getName(), Lists.newArrayList(t, t), t, mode));
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.ADD.getName(), Lists.newArrayList(t, t), t));
+ Operator.ADD.getName(), Lists.newArrayList(t, t), t, mode));
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
- Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t));
+ Operator.SUBTRACT.getName(), Lists.newArrayList(t, t), t, mode));
}
functionSet.addBuiltin(ScalarFunction.createBuiltinOperator(
Operator.DIVIDE.getName(),
@@ -173,15 +175,14 @@ public class ArithmeticExpr extends Expr {
for (int j = 0; j < Type.getNumericTypes().size(); j++) {
Type t2 = Type.getNumericTypes().get(j);
+ Type retType = Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false));
+ NullableMode mode = retType.isDecimalV3() ? NullableMode.CUSTOM : NullableMode.DEPEND_ON_ARGUMENT;
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.MULTIPLY.getName(), Lists.newArrayList(t1, t2), retType, mode));
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.ADD.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.ADD.getName(), Lists.newArrayList(t1, t2), retType, mode));
functionSet.addBuiltin(ScalarFunction.createVecBuiltinOperator(
- Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2),
- Type.getNextNumType(Type.getAssignmentCompatibleType(t1, t2, false))));
+ Operator.SUBTRACT.getName(), Lists.newArrayList(t1, t2), retType, mode));
}
}
@@ -532,6 +533,9 @@ public class ArithmeticExpr extends Expr {
// max(scale1, scale2))
scale = Math.max(t1Scale, t2Scale);
precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale;
+ } else {
+ scale = Math.max(t1Scale, t2Scale);
+ precision = widthOfIntPart2 + scale;
}
if (precision > ScalarType.MAX_DECIMAL128_PRECISION) {
// TODO(gabriel): if precision is bigger than 38?
@@ -556,6 +560,9 @@ public class ArithmeticExpr extends Expr {
break;
}
castChild(ScalarType.createDecimalV3Type(precision, targetScale), 0);
+ } else if (op == Operator.MOD) {
+ castChild(type, 0);
+ castChild(type, 1);
}
break;
case INT_DIVIDE:
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
index e1f24c4530..fed13dad19 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BinaryPredicate.java
@@ -372,9 +372,6 @@ public class BinaryPredicate extends Predicate implements Writable {
if (t1 == PrimitiveType.BIGINT && t2 == PrimitiveType.BIGINT) {
return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false);
}
- if (t1.isDecimalV3Type() || t2.isDecimalV3Type()) {
- return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false);
- }
if ((t1 == PrimitiveType.BIGINT || t1 == PrimitiveType.DECIMALV2)
&& (t2 == PrimitiveType.BIGINT || t2 == PrimitiveType.DECIMALV2)) {
return Type.DECIMALV2;
@@ -400,6 +397,11 @@ public class BinaryPredicate extends Predicate implements Writable {
}
}
+ if ((t1.isDecimalV3Type() && !t2.isStringType() && !t2.isFloatingPointType())
+ || (t2.isDecimalV3Type() && !t1.isStringType() && !t1.isFloatingPointType())) {
+ return Type.getAssignmentCompatibleType(getChild(0).getType(), getChild(1).getType(), false);
+ }
+
return Type.DOUBLE;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java
index 229769e6f2..82e80a766f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java
@@ -1647,6 +1647,17 @@ public class DateLiteral extends LiteralExpr {
type = ScalarType.getDefaultDateType(Type.DATE);
} else {
type = ScalarType.getDefaultDateType(Type.DATETIME);
+ if (type.isDatetimeV2() && microsecond != 0) {
+ int scale = 6;
+ for (int i = 0; i < 6; i++) {
+ if (microsecond % Math.pow(10.0, i + 1) > 0) {
+ break;
+ } else {
+ scale -= 1;
+ }
+ }
+ type = ScalarType.createDatetimeV2Type(scale);
+ }
}
if (checkRange() || checkDate()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index 6329d51d62..1dc9740aae 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -20,6 +20,7 @@
package org.apache.doris.analysis;
+import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.FunctionSet;
@@ -31,6 +32,7 @@ import org.apache.doris.common.Config;
import org.apache.doris.common.TreeNode;
import org.apache.doris.common.io.Writable;
import org.apache.doris.common.util.VectorizedUtil;
+import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ExprStats;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TExprNode;
@@ -2039,6 +2041,25 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
if (fn.functionName().equalsIgnoreCase("concat_ws")) {
return children.get(0).isNullable();
}
+ if (fn.functionName().equalsIgnoreCase(Operator.MULTIPLY.getName())
+ && fn.getReturnType().isDecimalV3()) {
+ if (ConnectContext.get() != null
+ && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+ return true;
+ } else {
+ return hasNullableChild();
+ }
+ }
+ if ((fn.functionName().equalsIgnoreCase(Operator.ADD.getName())
+ || fn.functionName().equalsIgnoreCase(Operator.SUBTRACT.getName()))
+ && fn.getReturnType().isDecimalV3()) {
+ if (ConnectContext.get() != null
+ && ConnectContext.get().getSessionVariable().checkOverflowForDecimal()) {
+ return true;
+ } else {
+ return hasNullableChild();
+ }
+ }
return true;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
index f9fd62dedc..e15a30824c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FloatLiteral.java
@@ -196,8 +196,8 @@ public class FloatLiteral extends LiteralExpr {
return res;
} else if (targetType.isDecimalV3()) {
DecimalLiteral res = new DecimalLiteral(new BigDecimal(value));
- res.setType(ScalarType.createDecimalV3Type(res.getType().getPrecision(),
- ((ScalarType) res.getType()).decimalScale()));
+ res.setType(ScalarType.createDecimalV3Type(targetType.getPrecision(),
+ ((ScalarType) targetType).decimalScale()));
return res;
}
return this;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index d9e8dad2fe..6ee34851ac 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -1314,8 +1314,11 @@ public class FunctionCallExpr extends Expr {
&& argTypes[i].isDecimalV3() && args[ix].isDecimalV2()) {
uncheckedCastChild(ScalarType.createDecimalV3Type(argTypes[i].getPrecision(),
((ScalarType) argTypes[i]).getScalarScale()), i);
+ } else if (fnName.getFunction().equalsIgnoreCase("money_format")
+ && children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
+ continue;
} else if (!argTypes[i].matchesType(args[ix]) && !(
- argTypes[i].isDateType() && args[ix].isDateType())
+ argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
&& (!fn.getReturnType().isDecimalV3()
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) {
uncheckedCastChild(args[ix], i);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SetOperationStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SetOperationStmt.java
index b599d2cb98..984f2c5822 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SetOperationStmt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SetOperationStmt.java
@@ -492,6 +492,12 @@ public class SetOperationStmt extends QueryStmt {
(ScalarType) selectTypeWithNullable.get(j).first,
(ScalarType) operands.get(i).getQueryStmt().getResultExprs().get(j).getType());
}
+ if (selectTypeWithNullable.get(j).first.isDecimalV3()
+ && operands.get(i).getQueryStmt().getResultExprs().get(j).getType().isDecimalV3()) {
+ selectTypeWithNullable.get(j).first = ScalarType.getAssignmentCompatibleDecimalV3Type(
+ (ScalarType) selectTypeWithNullable.get(j).first,
+ (ScalarType) operands.get(i).getQueryStmt().getResultExprs().get(j).getType());
+ }
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java
index 9fa1690e44..1094ca6564 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/StringLiteral.java
@@ -21,7 +21,6 @@
package org.apache.doris.analysis;
import org.apache.doris.catalog.PrimitiveType;
-import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.DdlException;
@@ -182,11 +181,11 @@ public class StringLiteral extends LiteralExpr {
public LiteralExpr convertToDate(Type targetType) throws AnalysisException {
LiteralExpr newLiteral = null;
try {
- newLiteral = new DateLiteral(value, ScalarType.getDefaultDateType(targetType));
+ newLiteral = new DateLiteral(value, targetType);
} catch (AnalysisException e) {
if (targetType.isScalarType(PrimitiveType.DATETIME)) {
- newLiteral = new DateLiteral(value, ScalarType.getDefaultDateType(Type.DATE));
- newLiteral.setType(ScalarType.getDefaultDateType(Type.DATETIME));
+ newLiteral = new DateLiteral(value, Type.DATE);
+ newLiteral.setType(Type.DATETIME);
} else if (targetType.isScalarType(PrimitiveType.DATETIMEV2)) {
newLiteral = new DateLiteral(value, Type.DATEV2);
newLiteral.setType(targetType);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/TimestampArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/TimestampArithmeticExpr.java
index 2fc7850abc..ca2c098f34 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/TimestampArithmeticExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/TimestampArithmeticExpr.java
@@ -285,11 +285,8 @@ public class TimestampArithmeticExpr extends Expr {
for (int i = 0; i < childrenTypes.length; ++i) {
// For varargs, we must compare with the last type in callArgs.argTypes.
int ix = Math.min(argTypes.length - 1, i);
- if (!childrenTypes[i].matchesType(argTypes[ix]) && Config.enable_date_conversion
- && !childrenTypes[i].isDateType() && (argTypes[ix].isDate() || argTypes[ix].isDatetime())) {
- uncheckedCastChild(ScalarType.getDefaultDateType(argTypes[ix]), i);
- } else if (!childrenTypes[i].matchesType(argTypes[ix]) && !(
- childrenTypes[i].isDateType() && argTypes[ix].isDateType())) {
+ if (!childrenTypes[i].matchesType(argTypes[ix]) && !(
+ childrenTypes[i].isDateOrDateTime() && argTypes[ix].isDateOrDateTime())) {
uncheckedCastChild(argTypes[ix], i);
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java
index 460ea01b7b..ac368176ed 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java
@@ -105,7 +105,10 @@ public abstract class ScanNode extends PlanNode {
protected Expr castToSlot(SlotDescriptor slotDesc, Expr expr) throws UserException {
PrimitiveType dstType = slotDesc.getType().getPrimitiveType();
PrimitiveType srcType = expr.getType().getPrimitiveType();
- if (dstType != srcType) {
+ if (PrimitiveType.typeWithPrecision.contains(dstType) && PrimitiveType.typeWithPrecision.contains(srcType)
+ && !slotDesc.getType().equals(expr.getType())) {
+ return expr.castTo(slotDesc.getType());
+ } else if (dstType != srcType) {
return expr.castTo(slotDesc.getType());
} else {
return expr;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 863cd9f639..74c5412705 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -182,6 +182,8 @@ public class SessionVariable implements Serializable, Writable {
public static final String ENABLE_PROJECTION = "enable_projection";
+ public static final String CHECK_OVERFLOW_FOR_DECIMAL = "check_overflow_for_decimal";
+
public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY
= "trim_tailing_spaces_for_external_table_query";
@@ -535,6 +537,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_PROJECTION)
private boolean enableProjection = true;
+ @VariableMgr.VarAttr(name = CHECK_OVERFLOW_FOR_DECIMAL)
+ private boolean checkOverflowForDecimal = false;
+
/**
* as the new optimizer is not mature yet, use this var
* to control whether to use new optimizer, remove it when
@@ -1194,6 +1199,10 @@ public class SessionVariable implements Serializable, Writable {
return enableProjection;
}
+ public boolean checkOverflowForDecimal() {
+ return checkOverflowForDecimal;
+ }
+
public boolean isTrimTailingSpacesForExternalTableQuery() {
return trimTailingSpacesForExternalTableQuery;
}
@@ -1329,6 +1338,7 @@ public class SessionVariable implements Serializable, Writable {
}
tResult.setEnableFunctionPushdown(enableFunctionPushdown);
+ tResult.setCheckOverflowForDecimal(checkOverflowForDecimal);
tResult.setFragmentTransmissionCompressionCodec(fragmentTransmissionCompressionCodec);
tResult.setEnableLocalExchange(enableLocalExchange);
tResult.setEnableNewShuffleHashMethod(enableNewShuffleHashMethod);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java
index 93ab19a516..8ef6ef7a37 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RoundLiteralInBinaryPredicatesRule.java
@@ -24,6 +24,7 @@ import org.apache.doris.analysis.BoolLiteral;
import org.apache.doris.analysis.DateLiteral;
import org.apache.doris.analysis.DecimalLiteral;
import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.IsNullPredicate;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.common.AnalysisException;
@@ -64,7 +65,7 @@ public class RoundLiteralInBinaryPredicatesRule implements ExprRewriteRule {
expr.setChild(1, literal);
return expr;
} else {
- return new BoolLiteral(true);
+ return new IsNullPredicate(expr0, true);
}
}
case GT:
@@ -117,7 +118,7 @@ public class RoundLiteralInBinaryPredicatesRule implements ExprRewriteRule {
expr.setChild(1, literal);
return expr;
} else {
- return new BoolLiteral(true);
+ return new IsNullPredicate(expr0, true);
}
}
case GT:
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateTableAsSelectStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateTableAsSelectStmtTest.java
index 06892a1042..35ea3e303a 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateTableAsSelectStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/CreateTableAsSelectStmtTest.java
@@ -95,12 +95,19 @@ public class CreateTableAsSelectStmtTest extends TestWithFeService {
createTableAsSelect(selectFromDecimal1);
if (Config.enable_decimal_conversion) {
Assertions.assertEquals(
- "CREATE TABLE `select_decimal_table_1` (\n" + " `_col0` decimal(38, 2) NULL\n" + ") ENGINE=OLAP\n"
- + "DUPLICATE KEY(`_col0`)\n" + "COMMENT 'OLAP'\n"
- + "DISTRIBUTED BY HASH(`_col0`) BUCKETS 10\n" + "PROPERTIES (\n"
+ "CREATE TABLE `select_decimal_table_1` (\n"
+ + " `_col0` decimal(38, 2) NULL\n"
+ + ") ENGINE=OLAP\n"
+ + "DUPLICATE KEY(`_col0`)\n"
+ + "COMMENT 'OLAP'\n"
+ + "DISTRIBUTED BY HASH(`_col0`) BUCKETS 10\n"
+ + "PROPERTIES (\n"
+ "\"replication_allocation\" = \"tag.location.default: 1\",\n"
- + "\"in_memory\" = \"false\",\n" + "\"storage_format\" = \"V2\","
- + "\n\"disable_auto_compaction\" = \"false\"\n" + ");",
+ + "\"in_memory\" = \"false\",\n"
+ + "\"storage_format\" = \"V2\",\n"
+ + "\"light_schema_change\" = \"true\",\n"
+ + "\"disable_auto_compaction\" = \"false\"\n"
+ + ");",
showCreateTableByName("select_decimal_table_1").getResultRows().get(0).get(1));
} else {
Assertions.assertEquals(
@@ -302,7 +309,8 @@ public class CreateTableAsSelectStmtTest extends TestWithFeService {
createTableAsSelect(createSql);
ShowResultSet showResultSet = showCreateTableByName("test_default_timestamp");
Assertions.assertEquals("CREATE TABLE `test_default_timestamp` (\n" + " `userId` varchar(65533) NOT NULL,\n"
- + " `date` datetime NULL DEFAULT CURRENT_TIMESTAMP\n"
+ + " `date` " + (Config.enable_date_conversion ? "datetimev2(0)" : "datetime")
+ + " NULL DEFAULT CURRENT_TIMESTAMP\n"
+ ") ENGINE=OLAP\n" + "DUPLICATE KEY(`userId`)\n"
+ "COMMENT 'OLAP'\n" + "DISTRIBUTED BY HASH(`userId`) BUCKETS 10\n" + "PROPERTIES (\n"
+ "\"replication_allocation\" = \"tag.location.default: 1\",\n" + "\"in_memory\" = \"false\",\n"
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/QueryStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/QueryStmtTest.java
index e0b8a2f9b0..130c23d427 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/QueryStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/QueryStmtTest.java
@@ -190,11 +190,7 @@ public class QueryStmtTest {
Assert.assertEquals(2, exprsMap.size());
constMap.clear();
constMap = getConstantExprMap(exprsMap, analyzer);
- if (Config.enable_decimal_conversion) {
- Assert.assertEquals(6, constMap.size());
- } else {
- Assert.assertEquals(0, constMap.size());
- }
+ Assert.assertEquals(0, constMap.size());
sql = "SELECT k1 FROM db1.baseall GROUP BY k1 HAVING EXISTS(SELECT k4 FROM db1.tbl1 GROUP BY k4 "
+ "HAVING SUM(k4) = k4);";
diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 39dc947fe1..c4f975462f 100755
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -498,12 +498,12 @@ public class SelectStmtTest {
Assert.assertTrue(dorisAssert
.query(sql3)
.explainQuery()
- .contains("`dt` = '2020-09-08 00:00:00'"));
+ .contains(Config.enable_date_conversion ? "`dt` = '2020-09-08'" : "`dt` = '2020-09-08 00:00:00'"));
String sql4 = "select count() from db1.date_partition_table where dt='2020-09-08'";
Assert.assertTrue(dorisAssert
.query(sql4)
.explainQuery()
- .contains("`dt` = '2020-09-08 00:00:00'"));
+ .contains(Config.enable_date_conversion ? "`dt` = '2020-09-08'" : "`dt` = '2020-09-08 00:00:00'"));
}
@Test
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
index 5b403f95b4..115cdfcc9a 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
@@ -1604,11 +1604,14 @@ public class QueryPlanTest extends TestWithFeService {
//valid date
String sql = "select day from tbl_int_date where day in ('2020-10-30')";
String explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainString.contains("PREDICATES: `day` IN ('2020-10-30 00:00:00')"));
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion ? "PREDICATES: `day` IN ('2020-10-30')"
+ : "PREDICATES: `day` IN ('2020-10-30 00:00:00')"));
//valid date
sql = "select day from tbl_int_date where day in ('2020-10-30','2020-10-29')";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainString.contains("PREDICATES: `day` IN ('2020-10-30 00:00:00', '2020-10-29 00:00:00')"));
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion
+ ? "PREDICATES: `day` IN ('2020-10-30', '2020-10-29')"
+ : "PREDICATES: `day` IN ('2020-10-30 00:00:00', '2020-10-29 00:00:00')"));
//valid datetime
sql = "select day from tbl_int_date where date in ('2020-10-30 12:12:30')";
@@ -1678,7 +1681,8 @@ public class QueryPlanTest extends TestWithFeService {
//valid date
String sql = "select day from tbl_int_date where day = '2020-10-30'";
String explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2020-10-30 00:00:00'"));
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion ? "PREDICATES: `day` = '2020-10-30'"
+ : "PREDICATES: `day` = '2020-10-30 00:00:00'"));
sql = "select day from tbl_int_date where day = from_unixtime(1196440219)";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2007-12-01 00:30:19'"));
@@ -1688,19 +1692,19 @@ public class QueryPlanTest extends TestWithFeService {
//valid date
sql = "select day from tbl_int_date where day = 20201030";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2020-10-30 00:00:00'"));
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion ? "PREDICATES: `day` = '2020-10-30'"
+ : "PREDICATES: `day` = '2020-10-30 00:00:00'"));
//valid date
sql = "select day from tbl_int_date where day = '20201030'";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2020-10-30 00:00:00'"));
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion ? "PREDICATES: `day` = '2020-10-30'"
+ : "PREDICATES: `day` = '2020-10-30 00:00:00'"));
//valid date contains micro second
sql = "select day from tbl_int_date where day = '2020-10-30 10:00:01.111111'";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- if (Config.enable_date_conversion) {
- Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2020-10-30 10:00:01.111111'"));
- } else {
- Assert.assertTrue(explainString.contains("PREDICATES: `day` = '2020-10-30 10:00:01'"));
- }
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion
+ ? "PREDICATES: `day` = '2020-10-30 10:00:01.111111'"
+ : "PREDICATES: `day` = '2020-10-30 10:00:01'"));
//invalid date
sql = "select day from tbl_int_date where day = '2020-10-32'";
@@ -1754,11 +1758,8 @@ public class QueryPlanTest extends TestWithFeService {
//valid datetime contains micro second
sql = "select day from tbl_int_date where date = '2020-10-30 10:00:01.111111'";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- if (Config.enable_date_conversion) {
- Assert.assertTrue(explainString.contains("PREDICATES: `date` = '2020-10-30 10:00:01.111111'"));
- } else {
- Assert.assertTrue(explainString.contains("PREDICATES: `date` = '2020-10-30 10:00:01'"));
- }
+ Assert.assertTrue(explainString.contains(Config.enable_date_conversion
+ ? "VEMPTYSET" : "PREDICATES: `date` = '2020-10-30 10:00:01'"));
//invalid datetime
sql = "select day from tbl_int_date where date = '2020-10-32'";
explainString = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
@@ -1890,8 +1891,13 @@ public class QueryPlanTest extends TestWithFeService {
+ " \"line_delimiter\" = \"\\n\","
+ " \"max_file_size\" = \"500MB\" );";
String explainStr = getSQLPlanOrErrorMsg("EXPLAIN " + sql);
- Assert.assertTrue(explainStr.contains("PREDICATES: `date` >= '2021-10-07 00:00:00',"
- + " `date` <= '2021-10-11 00:00:00'"));
+ if (Config.enable_date_conversion) {
+ Assert.assertTrue(explainStr.contains("PREDICATES: `date` >= '2021-10-07',"
+ + " `date` <= '2021-10-11'"));
+ } else {
+ Assert.assertTrue(explainStr.contains("PREDICATES: `date` >= '2021-10-07 00:00:00',"
+ + " `date` <= '2021-10-11 00:00:00'"));
+ }
}
// Fix: issue-#7929
diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteDateLiteralRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteDateLiteralRuleTest.java
index c3c252561f..40bc9bb2c3 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteDateLiteralRuleTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteDateLiteralRuleTest.java
@@ -18,7 +18,6 @@
package org.apache.doris.rewrite;
import org.apache.doris.common.AnalysisException;
-import org.apache.doris.common.Config;
import org.apache.doris.common.FeConstants;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.utframe.DorisAssert;
@@ -77,11 +76,7 @@ public class RewriteDateLiteralRuleTest {
public void testWithStringFormatDate() throws Exception {
String query = "select * from " + DB_NAME + ".tb1 where k1 > '2021030112334455'";
String planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_date_conversion) {
- Assert.assertTrue(planString.contains("`k1` > '2021-03-01 12:33:44.550000'"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > '2021-03-01 12:33:44'"));
- }
+ Assert.assertTrue(planString.contains("`k1` > '2021-03-01 12:33:44'"));
query = "select k1 > '20210301' from " + DB_NAME + ".tb1";
planString = dorisAssert.query(query).explainQuery();
@@ -89,11 +84,7 @@ public class RewriteDateLiteralRuleTest {
query = "select k1 > '20210301233234.34' from " + DB_NAME + ".tb1";
planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_date_conversion) {
- Assert.assertTrue(planString.contains("`k1` > '2021-03-01 23:32:34.340000'"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > '2021-03-01 23:32:34'"));
- }
+ Assert.assertTrue(planString.contains("`k1` > '2021-03-01 23:32:34'"));
query = "select * from " + DB_NAME + ".tb1 where k1 > '2021-03-01'";
planString = dorisAssert.query(query).explainQuery();
@@ -177,37 +168,21 @@ public class RewriteDateLiteralRuleTest {
public void testWithDoubleFormatDate() throws Exception {
String query = "select * from " + DB_NAME + ".tb1 where k1 > 20210301.22";
String planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_decimal_conversion) {
- Assert.assertTrue(planString.contains("`k1` > 20210301"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > 2.021030122E7"));
- }
+ Assert.assertTrue(planString.contains("`k1` > 2.021030122E7"));
query = "select k1 > 20210331.22 from " + DB_NAME + ".tb1";
planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_decimal_conversion) {
- Assert.assertTrue(planString.contains("`k1` > 20210331"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > 2.021033122E7"));
- }
+ Assert.assertTrue(planString.contains("`k1` > 2.021033122E7"));
}
public void testWithDoubleFormatDateV2() throws Exception {
String query = "select * from " + DB_NAME + ".tb2 where k1 > 20210301.22";
String planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_decimal_conversion) {
- Assert.assertTrue(planString.contains("`k1` > 20210301"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > 2.021030122E7"));
- }
+ Assert.assertTrue(planString.contains("`k1` > 2.021030122E7"));
query = "select k1 > 20210331.22 from " + DB_NAME + ".tb2";
planString = dorisAssert.query(query).explainQuery();
- if (Config.enable_decimal_conversion) {
- Assert.assertTrue(planString.contains("`k1` > 20210331"));
- } else {
- Assert.assertTrue(planString.contains("`k1` > 2.021033122E7"));
- }
+ Assert.assertTrue(planString.contains("`k1` > 2.021033122E7"));
}
public void testWithInvalidFormatDate() throws Exception {
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
similarity index 53%
copy from fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
copy to fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
index 2b3070eab9..55ff08f700 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
@@ -18,78 +18,67 @@
package org.apache.doris.udf;
import org.apache.doris.catalog.Type;
-import org.apache.doris.common.Pair;
import org.apache.doris.thrift.TJavaUdfExecutorCtorParams;
import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
-import com.google.common.base.Joiner;
-import com.google.common.collect.Lists;
+import com.google.common.base.Preconditions;
import org.apache.log4j.Logger;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.IOException;
-import java.lang.reflect.Constructor;
-import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.math.BigInteger;
-import java.net.MalformedURLException;
+import java.math.RoundingMode;
import java.net.URLClassLoader;
import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
import java.util.Arrays;
-public class UdfExecutor {
- private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
+public abstract class BaseExecutor {
+ private static final Logger LOG = Logger.getLogger(BaseExecutor.class);
// By convention, the function in the class must be called evaluate()
public static final String UDF_FUNCTION_NAME = "evaluate";
+ public static final String UDAF_CREATE_FUNCTION = "create";
+ public static final String UDAF_DESTROY_FUNCTION = "destroy";
+ public static final String UDAF_ADD_FUNCTION = "add";
+ public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
+ public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
+ public static final String UDAF_MERGE_FUNCTION = "merge";
+ public static final String UDAF_RESULT_FUNCTION = "getValue";
// Object to deserialize ctor params from BE.
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
+ protected static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
new TBinaryProtocol.Factory();
- private Object udf;
+ protected Object udf;
// setup by init() and cleared by close()
- private Method method;
- // setup by init() and cleared by close()
- private URLClassLoader classLoader;
+ protected URLClassLoader classLoader;
// Return and argument types of the function inferred from the udf method signature.
// The JavaUdfDataType enum maps it to corresponding primitive type.
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
+ protected JavaUdfDataType[] argTypes;
+ protected JavaUdfDataType retType;
// Input buffer from the backend. This is valid for the duration of an evaluate() call.
// These buffers are allocated in the BE.
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
+ protected final long inputBufferPtrs;
+ protected final long inputNullsPtrs;
+ protected final long inputOffsetsPtrs;
// Output buffer to return non-string values. These buffers are allocated in the BE.
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
-
- // Pre-constructed input objects for the UDF. This minimizes object creation overhead
- // as these objects are reused across calls to evaluate().
- private Object[] inputObjects;
- // inputArgs_[i] is either inputObjects[i] or null
- private Object[] inputArgs;
-
- private long outputOffset;
- private long rowIdx;
-
- private final long batchSizePtr;
- private Class[] argClass;
+ protected final long outputBufferPtr;
+ protected final long outputNullPtr;
+ protected final long outputOffsetsPtr;
+ protected final long outputIntermediateStatePtr;
+ protected Class[] argClass;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
* the backend.
*/
- public UdfExecutor(byte[] thriftParams) throws Exception {
+ public BaseExecutor(byte[] thriftParams) throws Exception {
TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
try {
@@ -97,14 +86,6 @@ public class UdfExecutor {
} catch (TException e) {
throw new InternalException(e.getMessage());
}
- String className = request.fn.scalar_fn.symbol;
- String jarFile = request.location;
- Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- batchSizePtr = request.batch_size_ptr;
inputBufferPtrs = request.input_buffer_ptrs;
inputNullsPtrs = request.input_nulls_ptrs;
inputOffsetsPtrs = request.input_offsets_ptrs;
@@ -114,18 +95,139 @@ public class UdfExecutor {
outputOffsetsPtr = request.output_offsets_ptr;
outputIntermediateStatePtr = request.output_intermediate_state_ptr;
- outputOffset = 0L;
- rowIdx = 0L;
+ Type[] parameterTypes = new Type[request.fn.arg_types.size()];
+ for (int i = 0; i < request.fn.arg_types.size(); ++i) {
+ parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
+ }
+ String jarFile = request.location;
+ Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- init(jarFile, className, retType, parameterTypes);
+ init(request, jarFile, funcRetType, parameterTypes);
}
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ protected abstract void init(TJavaUdfExecutorCtorParams request, String jarPath,
+ Type funcRetType, Type... parameterTypes) throws UdfRuntimeException;
+
+ protected Object[] allocateInputObjects(long row, int argClassOffset) throws UdfRuntimeException {
+ Object[] inputObjects = new Object[argTypes.length];
+
+ for (int i = 0; i < argTypes.length; ++i) {
+ if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
+ && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) {
+ inputObjects[i] = null;
+ continue;
+ }
+ switch (argTypes[i]) {
+ case BOOLEAN:
+ inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
+ break;
+ case TINYINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
+ break;
+ case SMALLINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case INT:
+ inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case BIGINT:
+ inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case FLOAT:
+ inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case DOUBLE:
+ inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ break;
+ case DATE: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case DATETIME: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case DATEV2: {
+ int data = UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case DATETIMEV2: {
+ long data = UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row);
+ inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + argClassOffset]);
+ break;
+ }
+ case LARGEINT: {
+ long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row;
+ byte[] bytes = new byte[argTypes[i].getLen()];
+ UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
+
+ inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
+ break;
+ }
+ case DECIMALV2:
+ case DECIMAL32:
+ case DECIMAL64:
+ case DECIMAL128: {
+ long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + argTypes[i].getLen() * row;
+ byte[] bytes = new byte[argTypes[i].getLen()];
+ UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen());
+
+ BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes));
+ inputObjects[i] = new BigDecimal(value, argTypes[i].getScale());
+ break;
+ }
+ case CHAR:
+ case VARCHAR:
+ case STRING: {
+ long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row));
+ long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
+ UdfUtils.UNSAFE.getLong(null,
+ UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
+ long base =
+ row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
+ UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
+ + offset - numBytes;
+ byte[] bytes = new byte[(int) numBytes];
+ UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
+ inputObjects[i] = new String(bytes, StandardCharsets.UTF_8);
+ break;
+ }
+ default:
+ throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]);
+ }
+ }
+ return inputObjects;
}
+ protected abstract long getCurrentOutputOffset(long row);
+
/**
* Close the class loader we may have created.
*/
@@ -140,91 +242,11 @@ public class UdfExecutor {
}
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
- method = null;
classLoader = null;
}
- /**
- * evaluate function called by the backend. The inputs to the UDF have
- * been serialized to 'input'
- */
- public void evaluate() throws UdfRuntimeException {
- int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr);
- try {
- if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.VARCHAR)
- || retType.equals(JavaUdfDataType.CHAR)) {
- // If this udf return variable-size type (e.g.) String, we have to allocate output
- // buffer multiple times until buffer size is enough to store output column. So we
- // always begin with the last evaluated row instead of beginning of this batch.
- rowIdx = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr + 8);
- if (rowIdx == 0) {
- outputOffset = 0L;
- }
- } else {
- rowIdx = 0;
- }
- for (; rowIdx < batchSize; rowIdx++) {
- allocateInputObjects(rowIdx);
- for (int i = 0; i < argTypes.length; ++i) {
- // Currently, -1 indicates this column is not nullable. So input argument is
- // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0.
- if (UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1
- || UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) {
- inputArgs[i] = inputObjects[i];
- } else {
- inputArgs[i] = null;
- }
- }
- // `storeUdfResult` is called to store udf result to output column. If true
- // is returned, current value is stored successfully. Otherwise, current result is
- // not processed successfully (e.g. current output buffer is not large enough) so
- // we break this loop directly.
- if (!storeUdfResult(evaluate(inputArgs), rowIdx)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx);
- return;
- }
- }
- } catch (Exception e) {
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, batchSize);
- }
- throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e);
- }
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx);
- }
- }
-
- /**
- * Evaluates the UDF with 'args' as the input to the UDF.
- */
- private Object evaluate(Object... args) throws UdfRuntimeException {
- try {
- return method.invoke(udf, args);
- } catch (Exception e) {
- throw new UdfRuntimeException("UDF failed to evaluate", e);
- }
- }
-
- public Method getMethod() {
- return method;
- }
-
// Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_
- private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException {
- if (obj == null) {
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
- throw new UdfRuntimeException("UDF failed to store null data to not null column");
- }
- UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 1);
- if (retType.equals(JavaUdfDataType.STRING)) {
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr)
- + 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset)));
- }
- return true;
- }
+ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0);
}
@@ -266,22 +288,22 @@ public class UdfExecutor {
return true;
}
case DATE: {
- long time = UdfUtils.convertToDate(obj, method.getReturnType());
+ long time = UdfUtils.convertToDate(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj, method.getReturnType());
+ long time = UdfUtils.convertToDateTime(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATEV2: {
- int time = UdfUtils.convertToDateV2(obj, method.getReturnType());
+ int time = UdfUtils.convertToDateV2(obj, retClass);
UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
return true;
}
case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj, method.getReturnType());
+ long time = UdfUtils.convertToDateTimeV2(obj, retClass);
UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
return true;
}
@@ -305,6 +327,7 @@ public class UdfExecutor {
return true;
}
case DECIMALV2: {
+ Preconditions.checkArgument(((BigDecimal) obj).scale() == 9, "Scale of DECIMALV2 must be 9");
BigInteger data = ((BigDecimal) obj).unscaledValue();
byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
//TODO: here is maybe overflow also, and may find a better way to handle
@@ -321,19 +344,41 @@ public class UdfExecutor {
UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
return true;
}
+ case DECIMAL32:
+ case DECIMAL64:
+ case DECIMAL128: {
+ BigDecimal retValue = ((BigDecimal) obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN);
+ BigInteger data = retValue.unscaledValue();
+ byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
+ //TODO: here is maybe overflow also, and may find a better way to handle
+ byte[] value = new byte[retType.getLen()];
+ if (data.signum() == -1) {
+ Arrays.fill(value, (byte) -1);
+ }
+
+ for (int index = 0; index < Math.min(bytes.length, value.length); ++index) {
+ value[index] = bytes[index];
+ }
+
+ UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
+ return true;
+ }
case CHAR:
case VARCHAR:
case STRING: {
long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr);
byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- if (outputOffset + bytes.length > bufferSize) {
+ long offset = getCurrentOutputOffset(row);
+ if (offset + bytes.length > bufferSize) {
return false;
}
- outputOffset += bytes.length;
+ offset += bytes.length;
UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row,
- Integer.parseUnsignedInt(String.valueOf(outputOffset)));
+ Integer.parseUnsignedInt(String.valueOf(offset)));
UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + outputOffset - bytes.length, bytes.length);
+ UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length);
+ updateOutputOffset(offset);
return true;
}
default:
@@ -341,191 +386,5 @@ public class UdfExecutor {
}
}
- // Preallocate the input objects that will be passed to the underlying UDF.
- // These objects are allocated once and reused across calls to evaluate()
- private void allocateInputObjects(long row) throws UdfRuntimeException {
- inputObjects = new Object[argTypes.length];
- inputArgs = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 2L * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]);
- break;
- }
- case LARGEINT: {
- long base =
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2: {
- long base =
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value, 9);
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row));
- long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
- long base =
- row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + offset - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes, StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]);
- }
- }
- }
-
- private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes)
- throws UdfRuntimeException {
- ArrayList<String> signatures = Lists.newArrayList();
- try {
- LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath);
- ClassLoader loader;
- if (jarPath != null) {
- // Save for cleanup.
- ClassLoader parent = getClass().getClassLoader();
- classLoader = UdfUtils.getClassLoader(jarPath, parent);
- loader = classLoader;
- } else {
- // for test
- loader = ClassLoader.getSystemClassLoader();
- }
- Class<?> c = Class.forName(udfPath, true, loader);
- Constructor<?> ctor = c.getConstructor();
- udf = ctor.newInstance();
- Method[] methods = c.getMethods();
- for (Method m : methods) {
- // By convention, the udf must contain the function "evaluate"
- if (!m.getName().equals(UDF_FUNCTION_NAME)) {
- continue;
- }
- signatures.add(m.toGenericString());
- argClass = m.getParameterTypes();
-
- // Try to match the arguments
- if (argClass.length != parameterTypes.length) {
- continue;
- }
- method = m;
- Pair<Boolean, JavaUdfDataType> returnType;
- if (argClass.length == 0 && parameterTypes.length == 0) {
- // Special case where the UDF doesn't take any input args
- returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
- if (!returnType.first) {
- continue;
- } else {
- retType = returnType.second;
- }
- argTypes = new JavaUdfDataType[0];
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
- return;
- }
- returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
- if (!returnType.first) {
- continue;
- } else {
- retType = returnType.second;
- }
- Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, argClass, false);
- if (!inputType.first) {
- continue;
- } else {
- argTypes = inputType.second;
- }
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
- return;
- }
-
- StringBuilder sb = new StringBuilder();
- sb.append("Unable to find evaluate function with the correct signature: ")
- .append(udfPath + ".evaluate(")
- .append(Joiner.on(", ").join(parameterTypes))
- .append(")\n")
- .append("UDF contains: \n ")
- .append(Joiner.on("\n ").join(signatures));
- throw new UdfRuntimeException(sb.toString());
- } catch (MalformedURLException e) {
- throw new UdfRuntimeException("Unable to load jar.", e);
- } catch (SecurityException e) {
- throw new UdfRuntimeException("Unable to load function.", e);
- } catch (ClassNotFoundException e) {
- throw new UdfRuntimeException("Unable to find class.", e);
- } catch (NoSuchMethodException e) {
- throw new UdfRuntimeException(
- "Unable to find constructor with no arguments.", e);
- } catch (IllegalArgumentException e) {
- throw new UdfRuntimeException(
- "Unable to call UDF constructor with no arguments.", e);
- } catch (Exception e) {
- throw new UdfRuntimeException("Unable to call create UDF instance.", e);
- }
- }
+ protected void updateOutputOffset(long offset) {}
}
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index f684363bf6..4f88fa967e 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -25,9 +25,6 @@ import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -35,99 +32,36 @@ import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
-import java.math.BigDecimal;
-import java.math.BigInteger;
import java.net.MalformedURLException;
-import java.net.URLClassLoader;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashMap;
/**
* udaf executor.
*/
-public class UdafExecutor {
- public static final String UDAF_CREATE_FUNCTION = "create";
- public static final String UDAF_DESTROY_FUNCTION = "destroy";
- public static final String UDAF_ADD_FUNCTION = "add";
- public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
- public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
- public static final String UDAF_MERGE_FUNCTION = "merge";
- public static final String UDAF_RESULT_FUNCTION = "getValue";
+public class UdafExecutor extends BaseExecutor {
+
private static final Logger LOG = Logger.getLogger(UdafExecutor.class);
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory();
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
- private final long inputPlacesPtr;
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
- private Object udaf;
+
+ private long inputPlacesPtr;
private HashMap<String, Method> allMethods;
private HashMap<Long, Object> stateObjMap;
- private URLClassLoader classLoader;
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
- private Class[] argClass;
private Class retClass;
/**
* Constructor to create an object.
*/
public UdafExecutor(byte[] thriftParams) throws Exception {
- TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
- TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
- try {
- deserializer.deserialize(request, thriftParams);
- } catch (TException e) {
- throw new InternalException(e.getMessage());
- }
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- inputBufferPtrs = request.input_buffer_ptrs;
- inputNullsPtrs = request.input_nulls_ptrs;
- inputOffsetsPtrs = request.input_offsets_ptrs;
- inputPlacesPtr = request.input_places_ptr;
-
- outputBufferPtr = request.output_buffer_ptr;
- outputNullPtr = request.output_null_ptr;
- outputOffsetsPtr = request.output_offsets_ptr;
- outputIntermediateStatePtr = request.output_intermediate_state_ptr;
- allMethods = new HashMap<>();
- stateObjMap = new HashMap<>();
- String className = request.fn.aggregate_fn.symbol;
- String jarFile = request.location;
- Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- init(jarFile, className, funcRetType, parameterTypes);
+ super(thriftParams);
}
/**
* close and invoke destroy function.
*/
+ @Override
public void close() {
- if (classLoader != null) {
- try {
- classLoader.close();
- } catch (Exception e) {
- // Log and ignore.
- LOG.debug("Error closing the URLClassloader.", e);
- }
- }
- // We are now un-usable (because the class loader has been
- // closed), so null out allMethods and classLoader.
allMethods = null;
- classLoader = null;
- }
-
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ super.close();
}
/**
@@ -142,11 +76,11 @@ public class UdafExecutor {
stateObjMap.putIfAbsent(curPlace, createAggState());
inputArgs[0] = stateObjMap.get(curPlace);
do {
- Object[] inputObjects = allocateInputObjects(idx);
+ Object[] inputObjects = allocateInputObjects(idx, 1);
for (int i = 0; i < argTypes.length; ++i) {
inputArgs[i + 1] = inputObjects[i];
}
- allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs);
+ allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs);
idx++;
} while (isSinglePlace && idx < rowEnd);
} while (idx < rowEnd);
@@ -160,7 +94,7 @@ public class UdafExecutor {
*/
public Object createAggState() throws UdfRuntimeException {
try {
- return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null);
+ return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udf, null);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to create: ", e);
}
@@ -172,7 +106,7 @@ public class UdafExecutor {
public void destroy() throws UdfRuntimeException {
try {
for (Object obj : stateObjMap.values()) {
- allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udaf, obj);
+ allMethods.get(UDAF_DESTROY_FUNCTION).invoke(udf, obj);
}
stateObjMap.clear();
} catch (Exception e) {
@@ -189,7 +123,7 @@ public class UdafExecutor {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
args[0] = stateObjMap.get((Long) place);
args[1] = new DataOutputStream(baos);
- allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udf, args);
return baos.toByteArray();
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to serialize: ", e);
@@ -206,12 +140,12 @@ public class UdafExecutor {
ByteArrayInputStream bins = new ByteArrayInputStream(data);
args[0] = createAggState();
args[1] = new DataInputStream(bins);
- allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args);
args[1] = args[0];
Long curPlace = place;
stateObjMap.putIfAbsent(curPlace, createAggState());
args[0] = stateObjMap.get(curPlace);
- allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args);
+ allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to merge: ", e);
}
@@ -222,14 +156,15 @@ public class UdafExecutor {
*/
public boolean getValue(long row, long place) throws UdfRuntimeException {
try {
- return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObjMap.get((Long) place)),
- row);
+ return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)),
+ row, retClass);
} catch (Exception e) {
throw new UdfRuntimeException("UDAF failed to result", e);
}
}
- private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException {
+ @Override
+ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
// If result is null, return true directly when row == 0 as we have already inserted default value.
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
@@ -237,234 +172,23 @@ public class UdafExecutor {
}
return true;
}
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0);
- }
- switch (retType) {
- case BOOLEAN: {
- boolean val = (boolean) obj;
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- val ? (byte) 1 : 0);
- return true;
- }
- case TINYINT: {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (byte) obj);
- return true;
- }
- case SMALLINT: {
- UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (short) obj);
- return true;
- }
- case INT: {
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (int) obj);
- return true;
- }
- case BIGINT: {
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (long) obj);
- return true;
- }
- case FLOAT: {
- UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (float) obj);
- return true;
- }
- case DOUBLE: {
- UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (double) obj);
- return true;
- }
- case DATE: {
- long time = UdfUtils.convertToDate(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATEV2: {
- long time = UdfUtils.convertToDateV2(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj, retClass);
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case LARGEINT: {
- BigInteger data = (BigInteger) obj;
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
-
- //here value is 16 bytes, so if result data greater than the maximum of 16 bytes
- //it will return a wrong num to backend;
- byte[] value = new byte[16];
- //check data is negative
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
- for (int index = 0; index < Math.min(bytes.length, value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
- return true;
- }
- case DECIMALV2: {
- BigInteger data = ((BigDecimal) obj).unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way to handle
- byte[] value = new byte[16];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
-
- for (int index = 0; index < Math.min(bytes.length, value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
- return true;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr);
- byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- long offset = Integer.toUnsignedLong(
- UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
- if (offset + bytes.length > bufferSize) {
- return false;
- }
- offset += bytes.length;
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row,
- Integer.parseUnsignedInt(String.valueOf(offset)));
- UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length);
- return true;
- }
- default:
- throw new UdfRuntimeException("Unsupported return type: " + retType);
- }
+ return super.storeUdfResult(obj, row, retClass);
}
- private Object[] allocateInputObjects(long row) throws UdfRuntimeException {
- Object[] inputObjects = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- // skip the input column of current row is null
- if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1
- && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) {
- inputObjects[i] = null;
- continue;
- }
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 2L * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + 1]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + 1]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + 1]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + 1]);
- break;
- }
- case LARGEINT: {
- long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2: {
- long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value, 9);
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i))
- + 4L * row));
- long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row
- - 1)));
- long base = row == 0 ? UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset
- - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes, StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]);
- }
- }
- return inputObjects;
+ @Override
+ protected long getCurrentOutputOffset(long row) {
+ return Integer.toUnsignedLong(
+ UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
}
- private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes)
- throws UdfRuntimeException {
+ @Override
+ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
+ Type... parameterTypes) throws UdfRuntimeException {
+ String className = request.fn.aggregate_fn.symbol;
+ inputPlacesPtr = request.input_places_ptr;
+ allMethods = new HashMap<>();
+ stateObjMap = new HashMap<>();
+
ArrayList<String> signatures = Lists.newArrayList();
try {
ClassLoader loader;
@@ -476,9 +200,9 @@ public class UdafExecutor {
// for test
loader = ClassLoader.getSystemClassLoader();
}
- Class<?> c = Class.forName(udfPath, true, loader);
+ Class<?> c = Class.forName(className, true, loader);
Constructor<?> ctor = c.getConstructor();
- udaf = ctor.newInstance();
+ udf = ctor.newInstance();
Method[] methods = c.getDeclaredMethods();
int idx = 0;
for (idx = 0; idx < methods.length; ++idx) {
@@ -534,7 +258,7 @@ public class UdafExecutor {
return;
}
StringBuilder sb = new StringBuilder();
- sb.append("Unable to find evaluate function with the correct signature: ").append(udfPath + ".evaluate(")
+ sb.append("Unable to find evaluate function with the correct signature: ").append(className + ".evaluate(")
.append(Joiner.on(", ").join(parameterTypes)).append(")\n").append("UDF contains: \n ")
.append(Joiner.on("\n ").join(signatures));
throw new UdfRuntimeException(sb.toString());
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
index 2b3070eab9..5f043f64a8 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java
@@ -25,123 +25,43 @@ import org.apache.doris.udf.UdfUtils.JavaUdfDataType;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import org.apache.log4j.Logger;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.protocol.TBinaryProtocol;
-import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
-import java.math.BigDecimal;
-import java.math.BigInteger;
import java.net.MalformedURLException;
-import java.net.URLClassLoader;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
-import java.util.Arrays;
-public class UdfExecutor {
+public class UdfExecutor extends BaseExecutor {
private static final Logger LOG = Logger.getLogger(UdfExecutor.class);
-
- // By convention, the function in the class must be called evaluate()
- public static final String UDF_FUNCTION_NAME = "evaluate";
-
- // Object to deserialize ctor params from BE.
- private static final TBinaryProtocol.Factory PROTOCOL_FACTORY =
- new TBinaryProtocol.Factory();
-
- private Object udf;
// setup by init() and cleared by close()
private Method method;
- // setup by init() and cleared by close()
- private URLClassLoader classLoader;
-
- // Return and argument types of the function inferred from the udf method signature.
- // The JavaUdfDataType enum maps it to corresponding primitive type.
- private JavaUdfDataType[] argTypes;
- private JavaUdfDataType retType;
-
- // Input buffer from the backend. This is valid for the duration of an evaluate() call.
- // These buffers are allocated in the BE.
- private final long inputBufferPtrs;
- private final long inputNullsPtrs;
- private final long inputOffsetsPtrs;
-
- // Output buffer to return non-string values. These buffers are allocated in the BE.
- private final long outputBufferPtr;
- private final long outputNullPtr;
- private final long outputOffsetsPtr;
- private final long outputIntermediateStatePtr;
// Pre-constructed input objects for the UDF. This minimizes object creation overhead
// as these objects are reused across calls to evaluate().
private Object[] inputObjects;
- // inputArgs_[i] is either inputObjects[i] or null
- private Object[] inputArgs;
private long outputOffset;
private long rowIdx;
- private final long batchSizePtr;
- private Class[] argClass;
+ private long batchSizePtr;
/**
* Create a UdfExecutor, using parameters from a serialized thrift object. Used by
* the backend.
*/
public UdfExecutor(byte[] thriftParams) throws Exception {
- TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams();
- TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY);
- try {
- deserializer.deserialize(request, thriftParams);
- } catch (TException e) {
- throw new InternalException(e.getMessage());
- }
- String className = request.fn.scalar_fn.symbol;
- String jarFile = request.location;
- Type retType = UdfUtils.fromThrift(request.fn.ret_type, 0).first;
- Type[] parameterTypes = new Type[request.fn.arg_types.size()];
- for (int i = 0; i < request.fn.arg_types.size(); ++i) {
- parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i));
- }
- batchSizePtr = request.batch_size_ptr;
- inputBufferPtrs = request.input_buffer_ptrs;
- inputNullsPtrs = request.input_nulls_ptrs;
- inputOffsetsPtrs = request.input_offsets_ptrs;
-
- outputBufferPtr = request.output_buffer_ptr;
- outputNullPtr = request.output_null_ptr;
- outputOffsetsPtr = request.output_offsets_ptr;
- outputIntermediateStatePtr = request.output_intermediate_state_ptr;
-
- outputOffset = 0L;
- rowIdx = 0L;
-
- init(jarFile, className, retType, parameterTypes);
- }
-
- @Override
- protected void finalize() throws Throwable {
- close();
- super.finalize();
+ super(thriftParams);
}
/**
* Close the class loader we may have created.
*/
+ @Override
public void close() {
- if (classLoader != null) {
- try {
- classLoader.close();
- } catch (IOException e) {
- // Log and ignore.
- LOG.debug("Error closing the URLClassloader.", e);
- }
- }
// We are now un-usable (because the class loader has been
// closed), so null out method_ and classLoader_.
method = null;
- classLoader = null;
+ super.close();
}
/**
@@ -164,24 +84,12 @@ public class UdfExecutor {
rowIdx = 0;
}
for (; rowIdx < batchSize; rowIdx++) {
- allocateInputObjects(rowIdx);
- for (int i = 0; i < argTypes.length; ++i) {
- // Currently, -1 indicates this column is not nullable. So input argument is
- // null iff inputNullsPtrs_ != -1 and nullCol[row_idx] != 0.
- if (UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1
- || UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + rowIdx) == 0) {
- inputArgs[i] = inputObjects[i];
- } else {
- inputArgs[i] = null;
- }
- }
+ inputObjects = allocateInputObjects(rowIdx, 0);
// `storeUdfResult` is called to store udf result to output column. If true
// is returned, current value is stored successfully. Otherwise, current result is
// not processed successfully (e.g. current output buffer is not large enough) so
// we break this loop directly.
- if (!storeUdfResult(evaluate(inputArgs), rowIdx)) {
+ if (!storeUdfResult(evaluate(inputObjects), rowIdx, method.getReturnType())) {
UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx);
return;
}
@@ -213,7 +121,8 @@ public class UdfExecutor {
}
// Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_
- private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException {
+ @Override
+ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
throw new UdfRuntimeException("UDF failed to store null data to not null column");
@@ -225,229 +134,31 @@ public class UdfExecutor {
}
return true;
}
- if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0);
- }
- switch (retType) {
- case BOOLEAN: {
- boolean val = (boolean) obj;
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- val ? (byte) 1 : 0);
- return true;
- }
- case TINYINT: {
- UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (byte) obj);
- return true;
- }
- case SMALLINT: {
- UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (short) obj);
- return true;
- }
- case INT: {
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (int) obj);
- return true;
- }
- case BIGINT: {
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (long) obj);
- return true;
- }
- case FLOAT: {
- UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (float) obj);
- return true;
- }
- case DOUBLE: {
- UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(),
- (double) obj);
- return true;
- }
- case DATE: {
- long time = UdfUtils.convertToDate(obj, method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIME: {
- long time = UdfUtils.convertToDateTime(obj, method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATEV2: {
- int time = UdfUtils.convertToDateV2(obj, method.getReturnType());
- UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case DATETIMEV2: {
- long time = UdfUtils.convertToDateTimeV2(obj, method.getReturnType());
- UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time);
- return true;
- }
- case LARGEINT: {
- BigInteger data = (BigInteger) obj;
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
-
- //here value is 16 bytes, so if result data greater than the maximum of 16 bytes
- //it will return a wrong num to backend;
- byte[] value = new byte[16];
- //check data is negative
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
- for (int index = 0; index < Math.min(bytes.length, value.length); ++index) {
- value[index] = bytes[index];
- }
-
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
- return true;
- }
- case DECIMALV2: {
- BigInteger data = ((BigDecimal) obj).unscaledValue();
- byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray());
- //TODO: here is maybe overflow also, and may find a better way to handle
- byte[] value = new byte[16];
- if (data.signum() == -1) {
- Arrays.fill(value, (byte) -1);
- }
+ return super.storeUdfResult(obj, row, retClass);
+ }
- for (int index = 0; index < Math.min(bytes.length, value.length); ++index) {
- value[index] = bytes[index];
- }
+ @Override
+ protected long getCurrentOutputOffset(long row) {
+ return outputOffset;
+ }
- UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length);
- return true;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr);
- byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8);
- if (outputOffset + bytes.length > bufferSize) {
- return false;
- }
- outputOffset += bytes.length;
- UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row,
- Integer.parseUnsignedInt(String.valueOf(outputOffset)));
- UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null,
- UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + outputOffset - bytes.length, bytes.length);
- return true;
- }
- default:
- throw new UdfRuntimeException("Unsupported return type: " + retType);
- }
+ @Override
+ protected void updateOutputOffset(long offset) {
+ outputOffset = offset;
}
// Preallocate the input objects that will be passed to the underlying UDF.
// These objects are allocated once and reused across calls to evaluate()
- private void allocateInputObjects(long row) throws UdfRuntimeException {
- inputObjects = new Object[argTypes.length];
- inputArgs = new Object[argTypes.length];
-
- for (int i = 0; i < argTypes.length; ++i) {
- switch (argTypes[i]) {
- case BOOLEAN:
- inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case TINYINT:
- inputObjects[i] = UdfUtils.UNSAFE.getByte(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row);
- break;
- case SMALLINT:
- inputObjects[i] = UdfUtils.UNSAFE.getShort(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 2L * row);
- break;
- case INT:
- inputObjects[i] = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case BIGINT:
- inputObjects[i] = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case FLOAT:
- inputObjects[i] = UdfUtils.UNSAFE.getFloat(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- break;
- case DOUBLE:
- inputObjects[i] = UdfUtils.UNSAFE.getDouble(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- break;
- case DATE: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i]);
- break;
- }
- case DATETIME: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i]);
- break;
- }
- case DATEV2: {
- int data = UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row);
- inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i]);
- break;
- }
- case DATETIMEV2: {
- long data = UdfUtils.UNSAFE.getLong(null,
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row);
- inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i]);
- break;
- }
- case LARGEINT: {
- long base =
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes));
- break;
- }
- case DECIMALV2: {
- long base =
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 16L * row;
- byte[] bytes = new byte[16];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16);
-
- BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes));
- inputObjects[i] = new BigDecimal(value, 9);
- break;
- }
- case CHAR:
- case VARCHAR:
- case STRING: {
- long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row));
- long numBytes = row == 0 ? offset : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null,
- UdfUtils.UNSAFE.getLong(null,
- UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1)));
- long base =
- row == 0 ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) :
- UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i))
- + offset - numBytes;
- byte[] bytes = new byte[(int) numBytes];
- UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes);
- inputObjects[i] = new String(bytes, StandardCharsets.UTF_8);
- break;
- }
- default:
- throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]);
- }
- }
- }
-
- private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes)
- throws UdfRuntimeException {
+ @Override
+ protected void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType,
+ Type... parameterTypes) throws UdfRuntimeException {
+ String className = request.fn.scalar_fn.symbol;
+ batchSizePtr = request.batch_size_ptr;
+ outputOffset = 0L;
+ rowIdx = 0L;
ArrayList<String> signatures = Lists.newArrayList();
try {
- LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loading UDF '" + className + "' from " + jarPath);
ClassLoader loader;
if (jarPath != null) {
// Save for cleanup.
@@ -458,7 +169,7 @@ public class UdfExecutor {
// for test
loader = ClassLoader.getSystemClassLoader();
}
- Class<?> c = Class.forName(udfPath, true, loader);
+ Class<?> c = Class.forName(className, true, loader);
Constructor<?> ctor = c.getConstructor();
udf = ctor.newInstance();
Method[] methods = c.getMethods();
@@ -485,7 +196,7 @@ public class UdfExecutor {
retType = returnType.second;
}
argTypes = new JavaUdfDataType[0];
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loaded UDF '" + className + "' from " + jarPath);
return;
}
returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType());
@@ -500,13 +211,13 @@ public class UdfExecutor {
} else {
argTypes = inputType.second;
}
- LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath);
+ LOG.debug("Loaded UDF '" + className + "' from " + jarPath);
return;
}
StringBuilder sb = new StringBuilder();
sb.append("Unable to find evaluate function with the correct signature: ")
- .append(udfPath + ".evaluate(")
+ .append(className + ".evaluate(")
.append(Joiner.on(", ").join(parameterTypes))
.append(")\n")
.append("UDF contains: \n ")
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
index 105daa21bc..d9fa55bb2f 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java
@@ -84,11 +84,16 @@ public class UdfUtils {
LARGEINT("LARGEINT", TPrimitiveType.LARGEINT, 16),
DECIMALV2("DECIMALV2", TPrimitiveType.DECIMALV2, 16),
DATEV2("DATEV2", TPrimitiveType.DATEV2, 4),
- DATETIMEV2("DATETIMEV2", TPrimitiveType.DATETIMEV2, 8);
+ DATETIMEV2("DATETIMEV2", TPrimitiveType.DATETIMEV2, 8),
+ DECIMAL32("DECIMAL32", TPrimitiveType.DECIMAL32, 4),
+ DECIMAL64("DECIMAL64", TPrimitiveType.DECIMAL64, 8),
+ DECIMAL128("DECIMAL128", TPrimitiveType.DECIMAL128I, 16);
private final String description;
private final TPrimitiveType thriftType;
private final int len;
+ private int precision;
+ private int scale;
JavaUdfDataType(String description, TPrimitiveType thriftType, int len) {
this.description = description;
@@ -135,7 +140,8 @@ public class UdfUtils {
} else if (c == BigInteger.class) {
return Sets.newHashSet(JavaUdfDataType.LARGEINT);
} else if (c == BigDecimal.class) {
- return Sets.newHashSet(JavaUdfDataType.DECIMALV2);
+ return Sets.newHashSet(JavaUdfDataType.DECIMALV2, JavaUdfDataType.DECIMAL32, JavaUdfDataType.DECIMAL64,
+ JavaUdfDataType.DECIMAL128);
}
return Sets.newHashSet(JavaUdfDataType.INVALID_TYPE);
}
@@ -151,6 +157,22 @@ public class UdfUtils {
}
return false;
}
+
+ public int getPrecision() {
+ return precision;
+ }
+
+ public void setPrecision(int precision) {
+ this.precision = precision;
+ }
+
+ public int getScale() {
+ return this.thriftType == TPrimitiveType.DECIMALV2 ? 9 : scale;
+ }
+
+ public void setScale(int scale) {
+ this.scale = scale;
+ }
}
protected static Pair<Type, Integer> fromThrift(TTypeDesc typeDesc, int nodeIdx) throws InternalException {
@@ -239,10 +261,13 @@ public class UdfUtils {
// type.
Object[] res = javaTypes.stream().filter(
t -> t.getPrimitiveType() == retType.getPrimitiveType().toThrift()).toArray();
- if (res.length == 0) {
- return Pair.of(false, (JavaUdfDataType) javaTypes.toArray()[0]);
+
+ JavaUdfDataType result = res.length == 0 ? (JavaUdfDataType) javaTypes.toArray()[0] : (JavaUdfDataType) res[0];
+ if (retType.isDecimalV3() || retType.isDatetimeV2()) {
+ result.setPrecision(retType.getPrecision());
+ result.setScale(((ScalarType) retType).getScalarScale());
}
- return Pair.of(true, (JavaUdfDataType) res[0]);
+ return Pair.of(res.length != 0, result);
}
/**
@@ -259,11 +284,13 @@ public class UdfUtils {
int finalI = i;
Object[] res = javaTypes.stream().filter(
t -> t.getPrimitiveType() == parameterTypes[finalI].getPrimitiveType().toThrift()).toArray();
+ inputArgTypes[i] = res.length == 0 ? (JavaUdfDataType) javaTypes.toArray()[0] : (JavaUdfDataType) res[0];
+ if (parameterTypes[finalI].isDecimalV3() || parameterTypes[finalI].isDatetimeV2()) {
+ inputArgTypes[i].setPrecision(parameterTypes[finalI].getPrecision());
+ inputArgTypes[i].setScale(((ScalarType) parameterTypes[finalI]).getScalarScale());
+ }
if (res.length == 0) {
- inputArgTypes[i] = (JavaUdfDataType) javaTypes.toArray()[0];
return Pair.of(false, inputArgTypes);
- } else {
- inputArgTypes[i] = (JavaUdfDataType) res[0];
}
}
return Pair.of(true, inputArgTypes);
diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py
index 6c9fb3eccd..68abbad9bc 100755
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -123,6 +123,9 @@ visible_functions = [
[['array'], 'ARRAY', ['FLOAT', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array'], 'ARRAY', ['DOUBLE', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array'], 'ARRAY', ['DECIMALV2', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array'], 'ARRAY', ['DECIMAL32', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array'], 'ARRAY', ['DECIMAL64', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array'], 'ARRAY', ['DECIMAL128', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array'], 'ARRAY', ['VARCHAR', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array'], 'ARRAY', ['STRING', '...'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
@@ -139,6 +142,9 @@ visible_functions = [
[['element_at', '%element_extract%'], 'FLOAT', ['ARRAY_FLOAT', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['element_at', '%element_extract%'], 'DOUBLE', ['ARRAY_DOUBLE', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['element_at', '%element_extract%'], 'DECIMALV2', ['ARRAY_DECIMALV2', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['element_at', '%element_extract%'], 'DECIMAL32', ['ARRAY_DECIMAL32', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['element_at', '%element_extract%'], 'DECIMAL64', ['ARRAY_DECIMAL64', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['element_at', '%element_extract%'], 'DECIMAL128', ['ARRAY_DECIMAL128', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['element_at', '%element_extract%'], 'VARCHAR', ['ARRAY_VARCHAR', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['element_at', '%element_extract%'], 'STRING', ['ARRAY_STRING', 'BIGINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -155,6 +161,9 @@ visible_functions = [
[['arrays_overlap'], 'BOOLEAN', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['arrays_overlap'], 'BOOLEAN', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['arrays_overlap'], 'BOOLEAN', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['arrays_overlap'], 'BOOLEAN', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['arrays_overlap'], 'BOOLEAN', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['arrays_overlap'], 'BOOLEAN', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['arrays_overlap'], 'BOOLEAN', ['ARRAY_VARCHAR', 'ARRAY_VARCHAR'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['arrays_overlap'], 'BOOLEAN', ['ARRAY_STRING', 'ARRAY_STRING'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -171,6 +180,9 @@ visible_functions = [
[['array_contains'], 'BOOLEAN', ['ARRAY_FLOAT', 'FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_contains'], 'BOOLEAN', ['ARRAY_DOUBLE', 'DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_contains'], 'BOOLEAN', ['ARRAY_DECIMALV2', 'DECIMALV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_contains'], 'BOOLEAN', ['ARRAY_DECIMAL32', 'DECIMAL32'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_contains'], 'BOOLEAN', ['ARRAY_DECIMAL64', 'DECIMAL64'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_contains'], 'BOOLEAN', ['ARRAY_DECIMAL128', 'DECIMAL128'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_contains'], 'BOOLEAN', ['ARRAY_VARCHAR', 'VARCHAR'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_contains'], 'BOOLEAN', ['ARRAY_STRING', 'STRING'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -188,6 +200,9 @@ visible_functions = [
[['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['array_enumerate'], 'ARRAY_BIGINT', ['ARRAY_STRING'], '', '', '', 'vec', ''],
@@ -204,6 +219,9 @@ visible_functions = [
[['countequal'], 'BIGINT', ['ARRAY_FLOAT', 'FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['countequal'], 'BIGINT', ['ARRAY_DOUBLE', 'DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['countequal'], 'BIGINT', ['ARRAY_DECIMALV2', 'DECIMALV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['countequal'], 'BIGINT', ['ARRAY_DECIMAL32', 'DECIMAL32'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['countequal'], 'BIGINT', ['ARRAY_DECIMAL64', 'DECIMAL64'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['countequal'], 'BIGINT', ['ARRAY_DECIMAL128', 'DECIMAL128'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['countequal'], 'BIGINT', ['ARRAY_VARCHAR', 'VARCHAR'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['countequal'], 'BIGINT', ['ARRAY_STRING', 'STRING'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -220,6 +238,9 @@ visible_functions = [
[['array_position'], 'BIGINT', ['ARRAY_FLOAT', 'FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_position'], 'BIGINT', ['ARRAY_DOUBLE', 'DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_position'], 'BIGINT', ['ARRAY_DECIMALV2', 'DECIMALV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_position'], 'BIGINT', ['ARRAY_DECIMAL32', 'DECIMAL32'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_position'], 'BIGINT', ['ARRAY_DECIMAL64', 'DECIMAL64'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_position'], 'BIGINT', ['ARRAY_DECIMAL128', 'DECIMAL128'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_position'], 'BIGINT', ['ARRAY_VARCHAR', 'VARCHAR'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_position'], 'BIGINT', ['ARRAY_STRING', 'STRING'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -237,6 +258,9 @@ visible_functions = [
[['array_distinct'], 'ARRAY_FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_distinct'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_distinct'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_distinct'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_distinct'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_distinct'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_distinct'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['array_distinct'], 'ARRAY_STRING', ['ARRAY_STRING'], '', '', '', 'vec', ''],
@@ -248,6 +272,9 @@ visible_functions = [
[['array_difference'], 'ARRAY_DOUBLE', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_difference'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_difference'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_difference'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_difference'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_difference'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_BOOLEAN', ['ARRAY_BOOLEAN'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_TINYINT', ['ARRAY_TINYINT'], '', '', '', 'vec', ''],
@@ -262,6 +289,9 @@ visible_functions = [
[['array_sort'], 'ARRAY_FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_sort'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_sort'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_sort'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['array_sort'], 'ARRAY_STRING', ['ARRAY_STRING'], '', '', '', 'vec', ''],
@@ -279,6 +309,9 @@ visible_functions = [
[['array_join'], 'STRING', ['ARRAY_FLOAT','VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_DOUBLE','VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_DECIMALV2','VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL32','VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL64','VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL128','VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_VARCHAR','VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_STRING','VARCHAR'], '', '', '', 'vec', ''],
# array_join takes three params
@@ -295,6 +328,9 @@ visible_functions = [
[['array_join'], 'STRING', ['ARRAY_FLOAT','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_DOUBLE','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_DECIMALV2','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL32','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL64','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
+ [['array_join'], 'STRING', ['ARRAY_DECIMAL128','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_VARCHAR','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
[['array_join'], 'STRING', ['ARRAY_STRING','VARCHAR', 'VARCHAR'], '', '', '', 'vec', ''],
@@ -307,6 +343,9 @@ visible_functions = [
[['array_min'], 'FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_min'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_min'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_min'], 'DECIMAL32',['ARRAY_DECIMAL32'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_min'], 'DECIMAL64',['ARRAY_DECIMAL64'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_min'], 'DECIMAL128',['ARRAY_DECIMAL128'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_min'], 'DATE', ['ARRAY_DATE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_min'], 'DATETIME', ['ARRAY_DATETIME'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_min'], 'DATEV2', ['ARRAY_DATEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -320,6 +359,9 @@ visible_functions = [
[['array_max'], 'FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_max'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_max'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_max'], 'DECIMAL32',['ARRAY_DECIMAL32'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_max'], 'DECIMAL64',['ARRAY_DECIMAL64'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_max'], 'DECIMAL128',['ARRAY_DECIMAL128'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_max'], 'DATE', ['ARRAY_DATE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_max'], 'DATETIME', ['ARRAY_DATETIME'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_max'], 'DATEV2', ['ARRAY_DATEV2'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -333,6 +375,9 @@ visible_functions = [
[['array_sum'], 'DOUBLE', ['ARRAY_FLOAT'], '', '', '','vec', 'ALWAYS_NULLABLE'],
[['array_sum'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_sum'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_sum'], 'DECIMAL32',['ARRAY_DECIMAL32'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_sum'], 'DECIMAL64',['ARRAY_DECIMAL64'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_sum'], 'DECIMAL128',['ARRAY_DECIMAL128'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_avg'], 'DOUBLE', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_avg'], 'DOUBLE', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_avg'], 'DOUBLE', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -342,6 +387,9 @@ visible_functions = [
[['array_avg'], 'DOUBLE', ['ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_avg'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_avg'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_avg'], 'DECIMAL32',['ARRAY_DECIMAL32'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_avg'], 'DECIMAL64',['ARRAY_DECIMAL64'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_avg'], 'DECIMAL128',['ARRAY_DECIMAL128'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_product'], 'DOUBLE', ['ARRAY_BOOLEAN'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_product'], 'DOUBLE', ['ARRAY_TINYINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_product'], 'DOUBLE', ['ARRAY_SMALLINT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -351,6 +399,9 @@ visible_functions = [
[['array_product'], 'DOUBLE', ['ARRAY_FLOAT'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_product'], 'DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_product'], 'DECIMALV2',['ARRAY_DECIMALV2'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_product'], 'DECIMAL32',['ARRAY_DECIMAL32'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_product'], 'DECIMAL64',['ARRAY_DECIMAL64'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
+ [['array_product'], 'DECIMAL128',['ARRAY_DECIMAL128'],'', '', '', 'vec', 'ALWAYS_NULLABLE'],
[['array_remove'], 'ARRAY_BOOLEAN', ['ARRAY_BOOLEAN', 'BOOLEAN'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_TINYINT', ['ARRAY_TINYINT', 'TINYINT'], '', '', '', 'vec', ''],
@@ -361,6 +412,9 @@ visible_functions = [
[['array_remove'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'FLOAT'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'DOUBLE'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_remove'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_remove'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_remove'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'DECIMAL128'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_DATETIME', ['ARRAY_DATETIME', 'DATETIME'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_DATE', ['ARRAY_DATE', 'DATE'], '', '', '', 'vec', ''],
[['array_remove'], 'ARRAY_DATETIMEV2', ['ARRAY_DATETIMEV2', 'DATETIMEV2'], '', '', '', 'vec', ''],
@@ -377,6 +431,9 @@ visible_functions = [
[['array_union'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_union'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_union'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_union'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_union'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_union'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_union'], 'ARRAY_DATETIME', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], '', '', '', 'vec', ''],
[['array_union'], 'ARRAY_DATE', ['ARRAY_DATE', 'ARRAY_DATE'], '', '', '', 'vec', ''],
[['array_union'], 'ARRAY_DATETIMEV2', ['ARRAY_DATETIMEV2', 'ARRAY_DATETIMEV2'], '', '', '', 'vec', ''],
@@ -393,6 +450,9 @@ visible_functions = [
[['array_except'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_except'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_except'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_except'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_except'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_except'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_except'], 'ARRAY_DATETIME', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], '', '', '', 'vec', ''],
[['array_except'], 'ARRAY_DATE', ['ARRAY_DATE', 'ARRAY_DATE'], '', '', '', 'vec', ''],
[['array_except'], 'ARRAY_DATETIMEV2', ['ARRAY_DATETIMEV2', 'ARRAY_DATETIMEV2'], '', '', '', 'vec', ''],
@@ -413,6 +473,9 @@ visible_functions = [
[['array_compact'], 'ARRAY_FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_compact'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_compact'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_compact'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_compact'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_compact'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_compact'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_BOOLEAN', ['ARRAY_BOOLEAN', 'ARRAY_BOOLEAN'], '', '', '', 'vec', ''],
@@ -424,6 +487,9 @@ visible_functions = [
[['array_intersect'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_intersect'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_intersect'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_intersect'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_DATETIME', ['ARRAY_DATETIME', 'ARRAY_DATETIME'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_DATE', ['ARRAY_DATE', 'ARRAY_DATE'], '', '', '', 'vec', ''],
[['array_intersect'], 'ARRAY_DATETIMEV2', ['ARRAY_DATETIMEV2', 'ARRAY_DATETIMEV2'], '', '', '', 'vec', ''],
@@ -442,6 +508,9 @@ visible_functions = [
[['array_slice', '%element_slice%'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_STRING', ['ARRAY_STRING', 'BIGINT'], '', '', '', 'vec', ''],
@@ -456,6 +525,9 @@ visible_functions = [
[['array_slice', '%element_slice%'], 'ARRAY_FLOAT', ['ARRAY_FLOAT', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
+ [['array_slice', '%element_slice%'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
[['array_slice', '%element_slice%'], 'ARRAY_STRING', ['ARRAY_STRING', 'BIGINT', 'BIGINT'], '', '', '', 'vec', ''],
@@ -470,6 +542,9 @@ visible_functions = [
[['array_popback'], 'ARRAY_FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['array_popback'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['array_popback'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['array_popback'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['array_popback'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['array_popback'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['array_popback'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['array_popback'], 'ARRAY_STRING', ['ARRAY_STRING'], '', '', '', 'vec', ''],
@@ -484,6 +559,9 @@ visible_functions = [
[['array_with_constant'], 'ARRAY_FLOAT', ['BIGINT', 'FLOAT'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array_with_constant'], 'ARRAY_DOUBLE', ['BIGINT', 'DOUBLE'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array_with_constant'], 'ARRAY_DECIMALV2', ['BIGINT', 'DECIMALV2'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array_with_constant'], 'ARRAY_DECIMAL32', ['BIGINT', 'DECIMAL32'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array_with_constant'], 'ARRAY_DECIMAL64', ['BIGINT', 'DECIMAL64'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
+ [['array_with_constant'], 'ARRAY_DECIMAL128', ['BIGINT', 'DECIMAL128'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array_with_constant'], 'ARRAY_VARCHAR', ['BIGINT', 'VARCHAR'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
[['array_with_constant'], 'ARRAY_STRING', ['BIGINT', 'STRING'], '', '', '', 'vec', 'ALWAYS_NOT_NULLABLE'],
@@ -507,6 +585,9 @@ visible_functions = [
[['reverse'], 'ARRAY_FLOAT', ['ARRAY_FLOAT'], '', '', '', 'vec', ''],
[['reverse'], 'ARRAY_DOUBLE', ['ARRAY_DOUBLE'], '', '', '', 'vec', ''],
[['reverse'], 'ARRAY_DECIMALV2', ['ARRAY_DECIMALV2'], '', '', '', 'vec', ''],
+ [['reverse'], 'ARRAY_DECIMAL32', ['ARRAY_DECIMAL32'], '', '', '', 'vec', ''],
+ [['reverse'], 'ARRAY_DECIMAL64', ['ARRAY_DECIMAL64'], '', '', '', 'vec', ''],
+ [['reverse'], 'ARRAY_DECIMAL128', ['ARRAY_DECIMAL128'], '', '', '', 'vec', ''],
[['reverse'], 'ARRAY_VARCHAR', ['ARRAY_VARCHAR'], '', '', '', 'vec', ''],
[['reverse'], 'ARRAY_STRING', ['ARRAY_STRING'], '', '', '', 'vec', ''],
@@ -1354,94 +1435,6 @@ visible_functions = [
'_ZN5doris18TimestampFunctions12seconds_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
'', '', 'vec', 'ALWAYS_NULLABLE'],
- [['years_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10years_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['months_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions11months_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['weeks_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10weeks_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['days_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions9days_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['hours_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10hours_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['minutes_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions12minutes_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['seconds_diff'], 'BIGINT', ['DATEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions12seconds_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
-
- [['years_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions10years_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['months_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions11months_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['weeks_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions10weeks_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['days_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions9days_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['hours_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions10hours_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['minutes_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions12minutes_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['seconds_diff'], 'BIGINT', ['DATETIME', 'DATEV2'],
- '_ZN5doris18TimestampFunctions12seconds_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
-
- [['years_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10years_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['months_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions11months_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['weeks_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10weeks_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['days_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions9days_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['hours_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions10hours_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['minutes_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions12minutes_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['seconds_diff'], 'BIGINT', ['DATETIMEV2', 'DATETIME'],
- '_ZN5doris18TimestampFunctions12seconds_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
-
- [['years_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions10years_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['months_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions11months_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['weeks_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions10weeks_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['days_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions9days_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['hours_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions10hours_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['minutes_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions12minutes_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
- [['seconds_diff'], 'BIGINT', ['DATETIME', 'DATETIMEV2'],
- '_ZN5doris18TimestampFunctions12seconds_diffEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValES6_',
- '', '', 'vec', 'ALWAYS_NULLABLE'],
-
[['year_floor'], 'DATETIMEV2', ['DATETIMEV2'],
'_ZN5doris18TimestampFunctions10year_floorEPN9doris_udf15FunctionContextERKNS1_11DateTimeV2ValE',
'', '', 'vec', 'ALWAYS_NULLABLE'],
@@ -2189,10 +2182,8 @@ visible_functions = [
[['ifnull', 'nvl'], 'DATETIME', ['DATETIME', 'DATETIME'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DATETIME', ['DATE', 'DATETIME'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DATETIME', ['DATETIME', 'DATE'], '', '', '', 'vec', 'CUSTOM'],
- [['ifnull', 'nvl'], 'DATEV2', ['DATE', 'DATE'], '', '', '', 'vec', 'CUSTOM'],
+ [['ifnull', 'nvl'], 'DATEV2', ['DATEV2', 'DATEV2'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DATETIMEV2', ['DATETIMEV2', 'DATETIMEV2'], '', '', '', 'vec', 'CUSTOM'],
- [['ifnull', 'nvl'], 'DATETIMEV2', ['DATEV2', 'DATETIMEV2'], '', '', '', 'vec', 'CUSTOM'],
- [['ifnull', 'nvl'], 'DATETIMEV2', ['DATETIMEV2', 'DATEV2'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DECIMALV2', ['DECIMALV2', 'DECIMALV2'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DECIMAL32', ['DECIMAL32', 'DECIMAL32'], '', '', '', 'vec', 'CUSTOM'],
[['ifnull', 'nvl'], 'DECIMAL64', ['DECIMAL64', 'DECIMAL64'], '', '', '', 'vec', 'CUSTOM'],
@@ -2402,8 +2393,8 @@ visible_functions = [
[['running_difference'], 'DECIMAL128', ['DECIMAL128'], '', '', '', 'vec', ''],
[['running_difference'], 'INT', ['DATE'], '', '', '', 'vec', ''],
[['running_difference'], 'INT', ['DATEV2'], '', '', '', 'vec', ''],
- [['running_difference'], 'DOUBLE', ['DATETIME'], '', '', '', 'vec', ''],
- [['running_difference'], 'DOUBLE', ['DATETIMEV2'], '', '', '', 'vec', ''],
+ [['running_difference'], 'TIME', ['DATETIME'], '', '', '', 'vec', ''],
+ [['running_difference'], 'TIMEV2', ['DATETIMEV2'], '', '', '', 'vec', ''],
# Longtext function
[['substr', 'substring'], 'STRING', ['STRING', 'INT'],
diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift
index 47a9144269..5d52e76f07 100644
--- a/gensrc/thrift/PaloInternalService.thrift
+++ b/gensrc/thrift/PaloInternalService.thrift
@@ -183,6 +183,8 @@ struct TQueryOptions {
53: optional i32 partitioned_hash_join_rows_threshold = 0
54: optional bool enable_share_hash_table_for_broadcast_join
+
+ 55: optional bool check_overflow_for_decimal = false
}
diff --git a/regression-test/data/correctness_p0/test_pushdown_constant.out b/regression-test/data/correctness_p0/test_pushdown_constant.out
index 095c7b2035..724665a290 100644
--- a/regression-test/data/correctness_p0/test_pushdown_constant.out
+++ b/regression-test/data/correctness_p0/test_pushdown_constant.out
@@ -2,3 +2,9 @@
-- !sql --
1
+-- !select_all --
+2022-01-01 2022-01-01T11:11:11
+
+-- !predicate --
+2022-01-01 2022-01-01T11:11:11
+
diff --git a/regression-test/data/data_model_p0/duplicate/storage/test_dup_tab_datetime_nullable.out b/regression-test/data/data_model_p0/duplicate/storage/test_dup_tab_datetime_nullable.out
index ec9b5d3bc5..9a7344d7e1 100644
--- a/regression-test/data/data_model_p0/duplicate/storage/test_dup_tab_datetime_nullable.out
+++ b/regression-test/data/data_model_p0/duplicate/storage/test_dup_tab_datetime_nullable.out
@@ -127,7 +127,6 @@
-- !datetime_as_pred_3 --
-- !datetime_as_pred_4 --
-\N
2021-01-02T23:10:04.111
2021-02-02T23:10:04.111
2021-03-02T23:10:04.111
diff --git a/regression-test/data/correctness_p0/test_pushdown_constant.out b/regression-test/data/datatype_p0/decimalv3/test_load.out
similarity index 53%
copy from regression-test/data/correctness_p0/test_pushdown_constant.out
copy to regression-test/data/datatype_p0/decimalv3/test_load.out
index 095c7b2035..35c355781e 100644
--- a/regression-test/data/correctness_p0/test_pushdown_constant.out
+++ b/regression-test/data/datatype_p0/decimalv3/test_load.out
@@ -1,4 +1,6 @@
-- This file is automatically generated. You should know what you did if you want to edit this
--- !sql --
-1
+-- !select_default --
+0.000132253565002480
+0.000135039891190100
+0.000390160098269360
diff --git a/regression-test/data/datatype_p0/decimalv3/test_overflow.out b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
new file mode 100644
index 0000000000..c9b9873cd7
--- /dev/null
+++ b/regression-test/data/datatype_p0/decimalv3/test_overflow.out
@@ -0,0 +1,19 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !select_all --
+11111111111111111111.100000000000000000 11111111111111111111.200000000000000000 11111111111111111111.300000000000000000 1.1000000000000000000000000000000000000 1.2000000000000000000000000000000000000 1.3000000000000000000000000000000000000 9
+
+-- !select_check_overflow1 --
+\N \N \N 99999999999999999999.900000000000000000 \N
+
+-- !select_check_overflow2 --
+1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 \N
+
+-- !select_check_overflow3 --
+11111111111111111111.100000000000000000 \N
+
+-- !select_not_check_overflow1 --
+99.999999999999999999999999999999999999 99.999999999999999999999999999999999999 1.1111111111111111E21 99999999999999999999.900000000000000000 99999999999999999999.999999999999999999
+
+-- !select_not_check_overflow2 --
+1.1000000000000000000000000000000000000 111111111111111111111.000000000000000000 -15.9141183460469231731687303715884105728
+
diff --git a/regression-test/data/datatype_p0/decimalv3/test_predicate.out b/regression-test/data/datatype_p0/decimalv3/test_predicate.out
new file mode 100644
index 0000000000..99787dfd4b
--- /dev/null
+++ b/regression-test/data/datatype_p0/decimalv3/test_predicate.out
@@ -0,0 +1,13 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !select1 --
+true
+
+-- !select2 --
+1
+1
+1
+
+-- !select3 --
+1.200000000000000000 1.200000000000000000 1.300000000000000000
+1.500000000000000000 1.200000000000000000 1.300000000000000000
+
diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out
index e0bb00b0e0..d8fb5e32af 100644
--- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out
+++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions.out
@@ -589,3 +589,15 @@
10005 [10005, NULL, NULL] [NULL]
10006 [60002, 60002, 60003, NULL, 60005] [NULL]
+-- !select_array_datetimev2_1 --
+1 [2023-01-19 18:11:11.111100, 2023-01-19 18:22:22.222200, 2023-01-19 18:33:33.333300] [2023-01-19 18:22:22.222200, 2023-01-19 18:33:33.333300, 2023-01-19 18:44:44.444400] [2023-01-19 18:11:11.111111, 2023-01-19 18:22:22.222222, 2023-01-19 18:33:33.333333]
+
+-- !select_array_datetimev2_2 --
+[2023-01-19 18:11:11.111100, 2023-01-19 18:22:22.222200, 2023-01-19 18:33:33.333300]
+
+-- !select_array_datetimev2_3 --
+[2023-01-19 18:22:22.222200, 2023-01-19 18:33:33.333300, 2023-01-19 18:44:44.444400]
+
+-- !select_array_datetimev2_4 --
+[2023-01-19 18:11:11.111111, 2023-01-19 18:22:22.222222, 2023-01-19 18:33:33.333333]
+
diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_running_difference.out b/regression-test/data/query_p0/sql_functions/math_functions/test_running_difference.out
index 4399a055d8..fc0dbc5fe1 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_running_difference.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_running_difference.out
@@ -33,12 +33,12 @@
2022-11-08 8
-- !test_running_difference_7 --
-2022-03-12T10:41 0.0
-2022-03-12T10:41:02 2.0
-2022-03-12T10:41:03 1.0
-2022-03-12T10:41:03 0.0
-2022-03-12T10:42:01 58.0
-2022-03-12T11:05:04 1383.0
+2022-03-12T10:41 00:00:00
+2022-03-12T10:41:02 00:00:02
+2022-03-12T10:41:03 00:00:01
+2022-03-12T10:41:03 00:00:00
+2022-03-12T10:42:01 00:00:58
+2022-03-12T11:05:04 00:23:03
-- !test_running_difference_8 --
\N \N
diff --git a/regression-test/suites/correctness_p0/test_pushdown_constant.groovy b/regression-test/suites/correctness_p0/test_pushdown_constant.groovy
index d392781373..dd38166017 100644
--- a/regression-test/suites/correctness_p0/test_pushdown_constant.groovy
+++ b/regression-test/suites/correctness_p0/test_pushdown_constant.groovy
@@ -16,9 +16,10 @@
// under the License.
suite("test_pushdown_constant") {
- sql """ DROP TABLE IF EXISTS `test_pushdown_constant` """
+ def tblName = "test_pushdown_constant"
+ sql """ DROP TABLE IF EXISTS `${tblName}` """
sql """
- CREATE TABLE IF NOT EXISTS `test_pushdown_constant` (
+ CREATE TABLE IF NOT EXISTS `${tblName}` (
`id` int
) ENGINE=OLAP
AGGREGATE KEY(`id`)
@@ -31,11 +32,33 @@ suite("test_pushdown_constant") {
);
"""
sql """
- insert into test_pushdown_constant values(1);
+ insert into ${tblName} values(1);
"""
qt_sql """
- select 1 from test_pushdown_constant where BITMAP_MAX( BITMAP_AND(BITMAP_EMPTY(), coalesce(NULL, bitmap_empty()))) is NULL;
+ select 1 from ${tblName} where BITMAP_MAX( BITMAP_AND(BITMAP_EMPTY(), coalesce(NULL, bitmap_empty()))) is NULL;
"""
+ sql """ DROP TABLE IF EXISTS `${tblName}` """
+
+ sql """
+ CREATE TABLE IF NOT EXISTS `${tblName}` (
+ `c1` date,
+ `c2` datetime
+ ) ENGINE=OLAP
+ COMMENT "OLAP"
+ DISTRIBUTED BY HASH(`c1`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "in_memory" = "false",
+ "storage_format" = "V2"
+ );
+ """
+ sql """
+ insert into ${tblName} values('20220101', '20220101111111');
+ """
+
+ qt_select_all """ select * from ${tblName} """
+ qt_predicate """ select * from ${tblName} where cast(c2 as date) = date '2022-01-01'"""
+ sql """ DROP TABLE IF EXISTS `${tblName}` """
}
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_data/test.csv b/regression-test/suites/datatype_p0/decimalv3/test_data/test.csv
new file mode 100644
index 0000000000..667fe01c08
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_data/test.csv
@@ -0,0 +1,3 @@
+0.00013225356500247968
+0.00039016009826936000
+0.00013503989119010048
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_load.groovy b/regression-test/suites/datatype_p0/decimalv3/test_load.groovy
new file mode 100644
index 0000000000..ba7b04ad94
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_load.groovy
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import org.codehaus.groovy.runtime.IOGroovyMethods
+
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+import java.nio.file.Paths
+
+suite("test_load") {
+ def dbName = "test_load"
+ sql "CREATE DATABASE IF NOT EXISTS ${dbName}"
+ sql "USE $dbName"
+ // Suite purpose: stream-load test_data/test.csv into a single decimalv3(38,18) column and query it back.
+ def tableName = "test_decimal_load"
+ try {
+ sql """ DROP TABLE IF EXISTS ${tableName} """
+ sql """
+ CREATE TABLE IF NOT EXISTS ${tableName} (
+ `a` decimalv3(38,18)
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`a`)
+ COMMENT 'OLAP'
+ DISTRIBUTED BY HASH(`a`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1"
+ );
+ """
+ // Shell out to curl: upload the CSV (-T) to the FE /api/<db>/<table>/_stream_load URL; only the exit code is asserted.
+ StringBuilder commandBuilder = new StringBuilder()
+ commandBuilder.append("""curl --location-trusted -u ${context.config.feHttpUser}:${context.config.feHttpPassword}""")
+ commandBuilder.append(""" -H format:csv -T ${context.file.parent}/test_data/test.csv http://${context.config.feHttpAddress}/api/""" + dbName + "/" + tableName + "/_stream_load")
+ command = commandBuilder.toString()
+ process = command.execute()
+ code = process.waitFor()
+ err = IOGroovyMethods.getText(new BufferedReader(new InputStreamReader(process.getErrorStream())))
+ out = process.getText()
+ logger.info("Run command: command=" + command + ",code=" + code + ", out=" + out + ", err=" + err)
+ assertEquals(code, 0)
+ qt_select_default """ SELECT * FROM ${tableName} t ORDER BY a; """
+ } finally { // cleanup runs whether or not the load and assertions above succeeded
+ try_sql("DROP TABLE IF EXISTS ${tableName}")
+ }
+}
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
new file mode 100644
index 0000000000..01de2ea498
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_overflow.groovy
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_overflow") {
+ // Runs the same decimalv3 arithmetic/cast queries with check_overflow_for_decimal enabled, then disabled.
+ def table1 = "test_overflow"
+ // Every statement below names the table through ${table1}, so create/insert/select/drop always agree.
+ sql "drop table if exists ${table1}"
+
+ sql """
+ CREATE TABLE IF NOT EXISTS ${table1} (
+ `k1` decimalv3(38, 18) NULL COMMENT "",
+ `k2` decimalv3(38, 18) NULL COMMENT "",
+ `k3` decimalv3(38, 18) NULL COMMENT "",
+ `v1` decimalv3(38, 37) NULL COMMENT "",
+ `v2` decimalv3(38, 37) NULL COMMENT "",
+ `v3` decimalv3(38, 37) NULL COMMENT "",
+ `v4` INT NULL COMMENT ""
+ ) ENGINE=OLAP
+ COMMENT "OLAP"
+ DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "in_memory" = "false",
+ "storage_format" = "V2"
+ )
+ """
+ // One row: k1..k3 carry 20 integer digits (scale 18), v1..v3 use scale 37, v4 is a plain INT multiplier.
+ sql """insert into ${table1} values(11111111111111111111.1,11111111111111111111.2,11111111111111111111.3, 1.1,1.2,1.3,9)
+ """
+ qt_select_all "select * from ${table1} order by k1"
+ // Overflow check enabled for the next three queries.
+ sql " SET check_overflow_for_decimal = true; "
+ qt_select_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from ${table1};"
+ qt_select_check_overflow2 "select v1, k1*10, v1 +k1*10 from ${table1}"
+ qt_select_check_overflow3 "select `k1`, cast (`k1` as DECIMALV3(38, 36)) from ${table1};"
+ // Overflow check disabled: the first two query shapes are repeated (the cast query is not).
+ sql " SET check_overflow_for_decimal = false; "
+ qt_select_not_check_overflow1 "select k1 * k2, k1 * k3, k1 * k2 * k3, k1 * v4, k1*50 from ${table1};"
+ qt_select_not_check_overflow2 "select v1, k1*10, v1 +k1*10 from ${table1}"
+ sql "drop table if exists ${table1}"
+}
diff --git a/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy b/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy
new file mode 100644
index 0000000000..62c28fa928
--- /dev/null
+++ b/regression-test/suites/datatype_p0/decimalv3/test_predicate.groovy
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_predicate") {
+ def table1 = "test_predicate"
+ // Checks predicates over decimal values: a FLOAT-vs-decimal equality (select1/select2) and a != filter on k1 (select3).
+ sql "drop table if exists ${table1}"
+
+ sql """
+ CREATE TABLE IF NOT EXISTS `${table1}` (
+ `k1` decimalv3(38, 18) NULL COMMENT "",
+ `k2` decimalv3(38, 18) NULL COMMENT "",
+ `k3` decimalv3(38, 18) NULL COMMENT ""
+ ) ENGINE=OLAP
+ COMMENT "OLAP"
+ DISTRIBUTED BY HASH(`k1`, `k2`, `k3`) BUCKETS 8
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "in_memory" = "false",
+ "storage_format" = "V2"
+ )
+ """
+ // Three rows: k2 and k3 are identical in every row; only k1 varies (1.1, 1.2, 1.5).
+ sql """insert into ${table1} values(1.1,1.2,1.3),
+ (1.2,1.2,1.3),
+ (1.5,1.2,1.3)
+ """
+ qt_select1 "SELECT CAST((CASE WHEN (TRUE IS NOT NULL) THEN '1.2' ELSE '1.2' END) AS FLOAT) = CAST(1.2 AS decimal(2,1))"
+ // select2 reuses the select1 comparison as a WHERE clause over the table; select3 filters on column k1 directly.
+ qt_select2 "SELECT 1 FROM ${table1} WHERE CAST((CASE WHEN (TRUE IS NOT NULL) THEN '1.2' ELSE '1.2' END) AS FLOAT) = CAST(1.2 AS decimal(2,1));"
+ qt_select3 "SELECT * FROM ${table1} WHERE k1 != 1.1 ORDER BY k1"
+ sql "drop table if exists ${table1}"
+}
diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy
index 1514692651..4b4203a461 100644
--- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy
+++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions.groovy
@@ -153,4 +153,31 @@ suite("test_array_functions") {
qt_select_union "select class_id, student_ids, array_union(student_ids,[1,2,3]) from ${tableName3} order by class_id;"
qt_select_except "select class_id, student_ids, array_except(student_ids,[1,2,3]) from ${tableName3} order by class_id;"
qt_select_intersect "select class_id, student_ids, array_intersect(student_ids,[1,2,3,null]) from ${tableName3} order by class_id;"
+
+ def tableName4 = "tbl_test_array_datetimev2_functions"
+
+ sql """DROP TABLE IF EXISTS ${tableName4}"""
+ sql """
+ CREATE TABLE IF NOT EXISTS ${tableName4} (
+ `k1` int COMMENT "",
+ `k2` ARRAY<datetimev2(4)> COMMENT "",
+ `k3` ARRAY<datetimev2(4)> COMMENT "",
+ `k4` ARRAY<datetimev2(6)> COMMENT ""
+ ) ENGINE=OLAP
+ DUPLICATE KEY(`k1`)
+ DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+ PROPERTIES (
+ "replication_allocation" = "tag.location.default: 1",
+ "storage_format" = "V2"
+ )
+ """
+ sql """ INSERT INTO ${tableName4} VALUES(1,
+ ["2023-01-19 18:11:11.1111","2023-01-19 18:22:22.2222","2023-01-19 18:33:33.3333"],
+ ["2023-01-19 18:22:22.2222","2023-01-19 18:33:33.3333","2023-01-19 18:44:44.4444"],
+ ["2023-01-19 18:11:11.111111","2023-01-19 18:22:22.222222","2023-01-19 18:33:33.333333"]) """
+
+ qt_select_array_datetimev2_1 "SELECT * FROM ${tableName4}"
+ qt_select_array_datetimev2_2 "SELECT if(1,k2,k3) FROM ${tableName4}"
+ qt_select_array_datetimev2_3 "SELECT if(0,k2,k3) FROM ${tableName4}"
+ qt_select_array_datetimev2_4 "SELECT if(0,k2,k4) FROM ${tableName4}"
}
diff --git a/regression-test/suites/query_p0/sql_functions/datetime_functions/test_date_function.groovy b/regression-test/suites/query_p0/sql_functions/datetime_functions/test_date_function.groovy
index d596dbba2a..e588c68569 100644
--- a/regression-test/suites/query_p0/sql_functions/datetime_functions/test_date_function.groovy
+++ b/regression-test/suites/query_p0/sql_functions/datetime_functions/test_date_function.groovy
@@ -569,7 +569,7 @@ suite("test_date_function") {
('2022-01-01', '2022-01-01', '2022-01-01 00:00:00', '2022-01-01 00:00:00'),
('2000-02-01', '2000-02-01', '2000-02-01 00:00:00', '2000-02-01 00:00:00.123'),
('2022-02-29', '2022-02-29', '2022-02-29 00:00:00', '2022-02-29 00:00:00'),
- ('2022-02-28', '2022-02-28', '2022-02-28 23:59:59', '2022-02-28 23:59:59');"""
+ ('2022-02-28', '2022-02-28', '2022-02-28T23:59:59', '2022-02-28T23:59:59');"""
qt_sql """
select last_day(birth), last_day(birth1),
last_day(birth2), last_day(birth3)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org