You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@doris.apache.org by mo...@apache.org on 2022/12/18 12:36:00 UTC

[doris] branch branch-1.2-lts updated: [cherrypick](datev2-decimalv3) refine function expr of datev2 & compute accurate round value by decimal (#15103)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
     new b6738c1938 [cherrypick](datev2-decimalv3) refine function expr of datev2 & compute accurate round value by decimal (#15103)
b6738c1938 is described below

commit b6738c193868d885dbdf0418f6e18a52afe09548
Author: Gabriel <ga...@gmail.com>
AuthorDate: Sun Dec 18 20:35:52 2022 +0800

    [cherrypick](datev2-decimalv3) refine function expr of datev2 & compute accurate round value by decimal (#15103)
---
 be/src/vec/functions/math.cpp                      | 119 +++--
 be/src/vec/functions/round.h                       | 556 +++++++++++++++++++++
 be/test/vec/function/function_math_test.cpp        |  35 +-
 .../apache/doris/analysis/FunctionCallExpr.java    | 147 +++---
 .../java/org/apache/doris/catalog/Function.java    |   4 -
 .../apache/doris/planner/ConstantExpressTest.java  |   7 -
 gensrc/script/doris_builtins_functions.py          |  60 ++-
 .../sql_functions/math_functions/test_round.out    |  12 +-
 .../sql_functions/math_functions/test_round.groovy |  22 +-
 9 files changed, 777 insertions(+), 185 deletions(-)

diff --git a/be/src/vec/functions/math.cpp b/be/src/vec/functions/math.cpp
index 417f576a52..67552b0f18 100644
--- a/be/src/vec/functions/math.cpp
+++ b/be/src/vec/functions/math.cpp
@@ -24,6 +24,7 @@
 #include "vec/functions/function_string.h"
 #include "vec/functions/function_totype.h"
 #include "vec/functions/function_unary_arithmetic.h"
+#include "vec/functions/round.h"
 #include "vec/functions/simple_function_factory.h"
 
 namespace doris::vectorized {
@@ -184,11 +185,6 @@ struct LogImpl {
 };
 using FunctionLog = FunctionBinaryArithmetic<LogImpl, LogName, true>;
 
-struct CeilName {
-    static constexpr auto name = "ceil";
-};
-using FunctionCeil = FunctionMathUnary<UnaryFunctionVectorized<CeilName, std::ceil, DataTypeInt64>>;
-
 template <typename A>
 struct SignImpl {
     using ResultType = Int8;
@@ -276,12 +272,6 @@ struct TanName {
 };
 using FunctionTan = FunctionMathUnary<UnaryFunctionVectorized<TanName, std::tan>>;
 
-struct FloorName {
-    static constexpr auto name = "floor";
-};
-using FunctionFloor =
-        FunctionMathUnary<UnaryFunctionVectorized<FloorName, std::floor, DataTypeInt64>>;
-
 template <typename A>
 struct RadiansImpl {
     using ResultType = A;
@@ -346,30 +336,6 @@ struct BinImpl {
 
 using FunctionBin = FunctionUnaryToType<BinImpl, NameBin>;
 
-struct RoundName {
-    static constexpr auto name = "round";
-};
-
-/// round(double)-->int64
-/// key_str:roundFloat64
-template <typename Name>
-struct RoundOneImpl {
-    using Type = DataTypeInt64;
-    static constexpr auto name = RoundName::name;
-    static constexpr auto rows_per_iteration = 1;
-    static constexpr bool always_returns_float64 = false;
-
-    static DataTypes get_variadic_argument_types() {
-        return {std::make_shared<vectorized::DataTypeFloat64>()};
-    }
-
-    template <typename T, typename U>
-    static void execute(const T* src, U* dst) {
-        dst[0] = static_cast<Int64>(std::round(static_cast<Float64>(src[0])));
-    }
-};
-using FunctionRoundOne = FunctionMathUnary<RoundOneImpl<RoundName>>;
-
 template <typename A, typename B>
 struct PowImpl {
     using ResultType = double;
@@ -386,52 +352,83 @@ struct PowName {
 };
 using FunctionPow = FunctionBinaryArithmetic<PowImpl, PowName, false>;
 
-template <typename A, typename B>
-struct TruncateImpl {
-    using ResultType = double;
-    static const constexpr bool allow_decimal = false;
-
-    template <typename type>
-    static inline double apply(A a, B b) {
-        /// Next everywhere, static_cast - so that there is no wrong result in expressions of the form Int64 c = UInt32(a) * Int32(-1).
-        return static_cast<Float64>(
-                my_double_round(static_cast<Float64>(a), static_cast<Int32>(b), false, true));
-    }
-};
 struct TruncateName {
     static constexpr auto name = "truncate";
 };
-using FunctionTruncate = FunctionBinaryArithmetic<TruncateImpl, TruncateName, false>;
+
+struct CeilName { // SQL name registered with RoundingMode::Ceil
+    static constexpr auto name = "ceil";
+};
+
+struct FloorName { // SQL name registered with RoundingMode::Floor
+    static constexpr auto name = "floor";
+};
+
+struct RoundName { // SQL name registered with RoundingMode::Round
+    static constexpr auto name = "round";
+};
 
 /// round(double,int32)-->double
 /// key_str:roundFloat64Int32
-template <typename A, typename B>
-struct RoundTwoImpl {
-    using ResultType = double;
-    static const constexpr bool allow_decimal = false;
+template <typename Name>
+struct DoubleRoundTwoImpl { // Name(DOUBLE, INT) overload for FunctionRounding
+    static constexpr auto name = Name::name;
 
     static DataTypes get_variadic_argument_types() {
         return {std::make_shared<vectorized::DataTypeFloat64>(),
                 std::make_shared<vectorized::DataTypeInt32>()};
     }
+};
 
-    template <typename type>
-    static inline double apply(A a, B b) {
-        /// Next everywhere, static_cast - so that there is no wrong result in expressions of the form Int64 c = UInt32(a) * Int32(-1).
-        return static_cast<Float64>(
-                my_double_round(static_cast<Float64>(a), static_cast<Int32>(b), false, false));
+template <typename Name>
+struct DoubleRoundOneImpl { // Name(DOUBLE) overload for FunctionRounding
+    static constexpr auto name = Name::name;
+
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<vectorized::DataTypeFloat64>()};
+    }
+};
+
+template <typename Name>
+struct DecimalRoundTwoImpl { // Name(DECIMAL, INT) overload for FunctionRounding
+    static constexpr auto name = Name::name;
+
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<vectorized::DataTypeDecimal<Decimal32>>(9, 0),
+                std::make_shared<vectorized::DataTypeInt32>()};
+    }
+};
+
+template <typename Name>
+struct DecimalRoundOneImpl { // Name(DECIMAL) overload for FunctionRounding
+    static constexpr auto name = Name::name;
+
+    static DataTypes get_variadic_argument_types() {
+        return {std::make_shared<vectorized::DataTypeDecimal<Decimal32>>(9, 0)};
     }
 };
-using FunctionRoundTwo = FunctionBinaryArithmetic<RoundTwoImpl, RoundName, false>;
 
 // TODO: Now math may cause one thread compile time too long, because the function in math
 // so mush. Split it to speed up compile time in the future
 void register_function_math(SimpleFunctionFactory& factory) {
+#define REGISTER_ROUND_FUNCTIONS(IMPL)                                                        \
+    factory.register_function<                                                                \
+            FunctionRounding<IMPL<RoundName>, RoundingMode::Round, TieBreakingMode::Auto>>(); \
+    factory.register_function<                                                                \
+            FunctionRounding<IMPL<FloorName>, RoundingMode::Floor, TieBreakingMode::Auto>>(); \
+    factory.register_function<                                                                \
+            FunctionRounding<IMPL<CeilName>, RoundingMode::Ceil, TieBreakingMode::Auto>>();   \
+    factory.register_function<                                                                \
+            FunctionRounding<IMPL<TruncateName>, RoundingMode::Trunc, TieBreakingMode::Auto>>();
+
+    REGISTER_ROUND_FUNCTIONS(DecimalRoundOneImpl)
+    REGISTER_ROUND_FUNCTIONS(DecimalRoundTwoImpl)
+    REGISTER_ROUND_FUNCTIONS(DoubleRoundOneImpl)
+    REGISTER_ROUND_FUNCTIONS(DoubleRoundTwoImpl)
     factory.register_function<FunctionAcos>();
     factory.register_function<FunctionAsin>();
     factory.register_function<FunctionAtan>();
     factory.register_function<FunctionCos>();
-    factory.register_function<FunctionCeil>();
     factory.register_alias("ceil", "dceil");
     factory.register_alias("ceil", "ceiling");
     factory.register_function<FunctionE>();
@@ -451,17 +448,13 @@ void register_function_math(SimpleFunctionFactory& factory) {
     factory.register_alias("sqrt", "dsqrt");
     factory.register_function<FunctionCbrt>();
     factory.register_function<FunctionTan>();
-    factory.register_function<FunctionFloor>();
     factory.register_alias("floor", "dfloor");
-    factory.register_function<FunctionRoundOne>();
-    factory.register_function<FunctionRoundTwo>();
     factory.register_function<FunctionPow>();
     factory.register_alias("pow", "power");
     factory.register_alias("pow", "dpow");
     factory.register_alias("pow", "fpow");
     factory.register_function<FunctionExp>();
     factory.register_alias("exp", "dexp");
-    factory.register_function<FunctionTruncate>();
     factory.register_function<FunctionRadians>();
     factory.register_function<FunctionDegrees>();
     factory.register_function<FunctionBin>();
diff --git a/be/src/vec/functions/round.h b/be/src/vec/functions/round.h
new file mode 100644
index 0000000000..679af9b4b2
--- /dev/null
+++ b/be/src/vec/functions/round.h
@@ -0,0 +1,556 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+// This file is copied from
+// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/FunctionRound.h
+// and modified by Doris
+
+#pragma once
+
+#ifdef __SSE4_1__
+#include <smmintrin.h>
+#else
+#include <fenv.h>
+#endif
+
+#include "vec/columns/column.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_number.h"
+
+namespace doris::vectorized {
+
+enum class ScaleMode { // chosen by Dispatcher from sign(scale_arg); decimals: always Negative
+    Positive, // scale_arg > 0: keep N digits after the decimal point
+    Negative, // scale_arg < 0: round to a multiple of 10^N (N trailing zeros)
+    Zero,     // scale_arg == 0: round to an integer
+};
+
+enum class RoundingMode { // on SSE4.1 builds, used directly as the _mm_round_* immediate
+#ifdef __SSE4_1__
+    Round = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC,
+    Floor = _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC,
+    Ceil = _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC,
+    Trunc = _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC,
+#else
+    Round = 8, /// Same numeric values as the SSE4.1 constants above, just in case.
+    Floor = 9,
+    Ceil = 10,
+    Trunc = 11,
+#endif
+};
+
+enum class TieBreakingMode {
+    Auto,    // floats: banker's rounding (hardware); integers/decimals: half away from zero
+    Bankers, // round half to even on the integer/decimal path as well
+};
+
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+          TieBreakingMode tie_breaking_mode>
+struct IntegerRoundingComputation { // per-value kernel; also used for decimal native values
+    static const size_t data_count = 1;
+
+    static size_t prepare(size_t scale) { return scale; } // nothing to precompute
+
+    /// Assumes wrap-around on overflow. NOTE(review): signed overflow is UB in C++.
+    static ALWAYS_INLINE T compute_impl(T x, T scale) {
+        switch (rounding_mode) {
+        case RoundingMode::Trunc: {
+            return x / scale * scale; // integer division truncates toward zero
+        }
+        case RoundingMode::Floor: {
+            if (x < 0) { // shift so truncating division rounds toward -inf
+                x -= scale - 1;
+            }
+            return x / scale * scale;
+        }
+        case RoundingMode::Ceil: {
+            if (x >= 0) { // shift so truncating division rounds toward +inf
+                x += scale - 1;
+            }
+            return x / scale * scale;
+        }
+        case RoundingMode::Round: {
+            if (x < 0) { // bias negatives so Auto rounds ties away from zero
+                x -= scale;
+            }
+            switch (tie_breaking_mode) {
+            case TieBreakingMode::Auto: {
+                x = (x + scale / 2) / scale * scale;
+                break;
+            }
+            case TieBreakingMode::Bankers: {
+                T quotient = (x + scale / 2) / scale;
+                if (quotient * scale == x + scale / 2) {
+                    // round half to even
+                    x = ((quotient + (x < 0)) & ~1) * scale;
+                } else {
+                    // round the others as usual
+                    x = quotient * scale;
+                }
+                break;
+            }
+            }
+            return x;
+        }
+        }
+
+        __builtin_unreachable();
+    }
+
+    static ALWAYS_INLINE T compute(T x, T scale) { // integers only change in Negative mode
+        switch (scale_mode) {
+        case ScaleMode::Zero:
+        case ScaleMode::Positive:
+            return x;
+        case ScaleMode::Negative:
+            return compute_impl(x, scale);
+        }
+
+        __builtin_unreachable();
+    }
+
+    static ALWAYS_INLINE void compute(const T* __restrict in, size_t scale, T* __restrict out) {
+        if constexpr (sizeof(T) <= sizeof(scale) && scale_mode == ScaleMode::Negative) {
+            if (scale > size_t(std::numeric_limits<T>::max())) {
+                *out = 0; // 10^N is out of T's range: the result collapses to 0
+                return;
+            }
+        }
+        *out = compute(*in, scale);
+    }
+};
+
+template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
+class DecimalRoundingImpl { // rounds on the scaled integer (NativeType) representation
+private:
+    using NativeType = typename T::NativeType;
+    using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative,
+                                          tie_breaking_mode>;
+    using Container = typename ColumnDecimal<T>::Container;
+
+public:
+    static NO_INLINE void apply(const Container& in, UInt32 in_scale, Container& out,
+                                Int16 scale_arg) {
+        scale_arg = in_scale - scale_arg; // number of trailing digits to clear
+        if (scale_arg > 0) {
+            size_t scale = int_exp10(scale_arg);
+
+            const NativeType* __restrict p_in = reinterpret_cast<const NativeType*>(in.data());
+            const NativeType* end_in = reinterpret_cast<const NativeType*>(in.data()) + in.size();
+            NativeType* __restrict p_out = reinterpret_cast<NativeType*>(out.data());
+
+            while (p_in < end_in) {
+                Op::compute(p_in, scale, p_out);
+                ++p_in;
+                ++p_out;
+            }
+        } else { // nothing to clear at this precision: copy unchanged
+            memcpy(out.data(), in.data(), in.size() * sizeof(T));
+        }
+    }
+};
+
+#ifdef __SSE4_1__
+
+template <typename T>
+class BaseFloatRoundingComputation; // specialised below for Float32 and Float64 only
+
+template <>
+class BaseFloatRoundingComputation<Float32> { // 4 floats per __m128
+public:
+    using ScalarType = Float32;
+    using VectorType = __m128;
+    static const size_t data_count = 4;
+
+    static VectorType load(const ScalarType* in) { return _mm_loadu_ps(in); }
+    static VectorType load1(const ScalarType in) { return _mm_load1_ps(&in); } // broadcast
+    static void store(ScalarType* out, VectorType val) { _mm_storeu_ps(out, val); }
+    static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_ps(val, scale); }
+    static VectorType divide(VectorType val, VectorType scale) { return _mm_div_ps(val, scale); }
+    template <RoundingMode mode>
+    static VectorType apply(VectorType val) {
+        return _mm_round_ps(val, int(mode)); // mode is the SSE4.1 rounding immediate
+    }
+
+    static VectorType prepare(size_t scale) { return load1(scale); } // scale in every lane
+};
+
+template <>
+class BaseFloatRoundingComputation<Float64> { // 2 doubles per __m128d
+public:
+    using ScalarType = Float64;
+    using VectorType = __m128d;
+    static const size_t data_count = 2;
+
+    static VectorType load(const ScalarType* in) { return _mm_loadu_pd(in); }
+    static VectorType load1(const ScalarType in) { return _mm_load1_pd(&in); } // broadcast
+    static void store(ScalarType* out, VectorType val) { _mm_storeu_pd(out, val); }
+    static VectorType multiply(VectorType val, VectorType scale) { return _mm_mul_pd(val, scale); }
+    static VectorType divide(VectorType val, VectorType scale) { return _mm_div_pd(val, scale); }
+    template <RoundingMode mode>
+    static VectorType apply(VectorType val) {
+        return _mm_round_pd(val, int(mode)); // mode is the SSE4.1 rounding immediate
+    }
+
+    static VectorType prepare(size_t scale) { return load1(scale); } // scale in both lanes
+};
+
+#else
+
+/// Fallback when SSE4.1 is unavailable (e.g. ARM): scalar, not vectorized.
+
+inline float roundWithMode(float x, RoundingMode mode) {
+    switch (mode) {
+    case RoundingMode::Round:
+        return nearbyintf(x); // current FP mode; FunctionRounding sets FE_TONEAREST
+    case RoundingMode::Floor:
+        return floorf(x);
+    case RoundingMode::Ceil:
+        return ceilf(x);
+    case RoundingMode::Trunc:
+        return truncf(x);
+    }
+
+    __builtin_unreachable();
+}
+
+inline double roundWithMode(double x, RoundingMode mode) { // scalar rounding for doubles
+    switch (mode) {
+    case RoundingMode::Round:
+        return nearbyint(x); // current FP mode; FunctionRounding sets FE_TONEAREST
+    case RoundingMode::Floor:
+        return floor(x);
+    case RoundingMode::Ceil:
+        return ceil(x);
+    case RoundingMode::Trunc:
+        return trunc(x);
+    }
+
+    __builtin_unreachable();
+}
+
+template <typename T>
+class BaseFloatRoundingComputation { // scalar stand-in: one value per "vector"
+public:
+    using ScalarType = T;
+    using VectorType = T;
+    static const size_t data_count = 1;
+
+    static VectorType load(const ScalarType* in) { return *in; }
+    static VectorType load1(const ScalarType in) { return in; }
+    static VectorType store(ScalarType* out, ScalarType val) { return *out = val; }
+    static VectorType multiply(VectorType val, VectorType scale) { return val * scale; }
+    static VectorType divide(VectorType val, VectorType scale) { return val / scale; }
+    template <RoundingMode mode>
+    static VectorType apply(VectorType val) {
+        return roundWithMode(val, mode);
+    }
+
+    static VectorType prepare(size_t scale) { return load1(scale); } // scale as T
+};
+
+#endif
+
+/** Low-level rounding of one vector (Base::data_count values) of floating-point numbers.
+  */
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
+class FloatRoundingComputation : public BaseFloatRoundingComputation<T> {
+    using Base = BaseFloatRoundingComputation<T>;
+
+public:
+    static inline void compute(const T* __restrict in, const typename Base::VectorType& scale,
+                               T* __restrict out) {
+        auto val = Base::load(in);
+
+        if (scale_mode == ScaleMode::Positive) { // x * 10^N: kept digits move before the point
+            val = Base::multiply(val, scale);
+        } else if (scale_mode == ScaleMode::Negative) { // x / 10^N
+            val = Base::divide(val, scale);
+        }
+
+        val = Base::template apply<rounding_mode>(val); // round to an integral value
+
+        if (scale_mode == ScaleMode::Positive) { // undo the scaling (inf if x * 10^N overflowed)
+            val = Base::divide(val, scale);
+        } else if (scale_mode == ScaleMode::Negative) {
+            val = Base::multiply(val, scale);
+        }
+
+        Base::store(out, val);
+    }
+};
+
+/** Rounds a whole floating-point column, Op::data_count values per step.
+  */
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
+struct FloatRoundingImpl {
+private:
+    static_assert(!IsDecimalNumber<T>);
+
+    using Op = FloatRoundingComputation<T, rounding_mode, scale_mode>;
+    using Data = std::array<T, Op::data_count>;
+    using ColumnType = ColumnVector<T>;
+    using Container = typename ColumnType::Container;
+
+public:
+    static NO_INLINE void apply(const Container& in, size_t scale, Container& out) {
+        auto mm_scale = Op::prepare(scale);
+
+        const size_t data_count = std::tuple_size<Data>();
+
+        const T* end_in = in.data() + in.size();
+        const T* limit = in.data() + in.size() / data_count * data_count; // end of full vectors
+
+        const T* __restrict p_in = in.data();
+        T* __restrict p_out = out.data();
+
+        while (p_in < limit) {
+            Op::compute(p_in, mm_scale, p_out);
+            p_in += data_count;
+            p_out += data_count;
+        }
+
+        if (p_in < end_in) { // tail shorter than one vector: round via padded buffers
+            Data tmp_src {{}}; // zero-initialised padding
+            Data tmp_dst;
+
+            size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in);
+
+            memcpy(&tmp_src, p_in, tail_size_bytes);
+            Op::compute(reinterpret_cast<T*>(&tmp_src), mm_scale, reinterpret_cast<T*>(&tmp_dst));
+            memcpy(p_out, &tmp_dst, tail_size_bytes);
+        }
+    }
+};
+
+/** Element-wise rounding of an integer column via IntegerRoundingComputation.
+  * Dispatcher selects it for every T that is neither decimal nor floating point.
+  * Only ScaleMode::Negative changes values; with Zero and Positive each value is
+  * copied through unchanged, since an integer has no fractional digits.
+  */
+template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode,
+          TieBreakingMode tie_breaking_mode>
+struct IntegerRoundingImpl {
+private:
+    using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, tie_breaking_mode>;
+    using Container = typename ColumnVector<T>::Container;
+
+    /// Largest N with a specialised kernel: 10^19 is the largest power of ten
+    /// that fits in size_t (2^64 - 1 is about 1.8 * 10^19).
+    static constexpr int max_scale_exponent = 19;
+
+public:
+    /// Rounds every element of `in` with a compile-time `scale` and stores the
+    /// result at the same index of `out`, which must hold in.size() elements.
+    template <size_t scale>
+    static NO_INLINE void applyImpl(const Container& in, Container& out) {
+        const T* end_in = in.data() + in.size();
+
+        const T* __restrict p_in = in.data();
+        T* __restrict p_out = out.data();
+
+        while (p_in < end_in) {
+            Op::compute(p_in, scale, p_out);
+            ++p_in;
+            ++p_out;
+        }
+    }
+
+    /// Entry point used by Dispatcher; `scale` is expected to be 10^N.
+    /// For N = 0 .. max_scale_exponent a dedicated applyImpl<10^N> is selected
+    /// (manual function cloning, so the compiler emits division by a constant).
+    /// Any other value used to reach __builtin_unreachable(), i.e. undefined
+    /// behaviour; it now goes through apply_generic(), which computes the same
+    /// per-element result with a runtime divisor.
+    static NO_INLINE void apply(const Container& in, size_t scale, Container& out) {
+        dispatch(in, scale, out);
+    }
+
+private:
+    /// Compile-time walk over P = 10^N for N = 0 .. max_scale_exponent; replaces
+    /// the hand-written `switch` over the twenty supported powers of ten.
+    template <int N = 0, size_t P = 1>
+    static void dispatch(const Container& in, size_t scale, Container& out) {
+        // Try the current power of ten, otherwise recurse to the next one.
+        if (scale == P) {
+            applyImpl<P>(in, out);
+        } else if constexpr (N < max_scale_exponent) {
+            // Discarded when N == max_scale_exponent, so P * 10 never overflows.
+            dispatch<N + 1, P * 10>(in, scale, out);
+        } else {
+            apply_generic(in, scale, out);
+        }
+    }
+
+    /// Fallback for a scale that is not one of 10^0 .. 10^19. Same per-element
+    /// computation as applyImpl, but with the divisor known only at run time.
+    /// For ScaleMode::Negative, Op::compute already yields 0 when the scale is
+    /// larger than T's maximum (for T no wider than size_t).
+    static NO_INLINE void apply_generic(const Container& in, size_t scale, Container& out) {
+        const size_t rows = in.size();
+        for (size_t i = 0; i < rows; ++i) {
+            Op::compute(&in[i], scale, &out[i]);
+        }
+    }
+};
+
+/** Picks the column-level implementation for T (decimal, floating point or integer).
+  */
+template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
+struct Dispatcher {
+    template <ScaleMode scale_mode> // ignored by the decimal path
+    using FunctionRoundingImpl = std::conditional_t<
+            IsDecimalNumber<T>, DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>,
+            std::conditional_t<
+                    std::is_floating_point_v<T>, FloatRoundingImpl<T, rounding_mode, scale_mode>,
+                    IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>>;
+
+    static ColumnPtr apply(const IColumn* col_general, Int16 scale_arg) {
+        if constexpr (IsNumber<T>) { // NOTE(review): `col` below is not null-checked
+            const auto* const col = check_and_get_column<ColumnVector<T>>(col_general);
+            auto col_res = ColumnVector<T>::create();
+
+            typename ColumnVector<T>::Container& vec_res = col_res->get_data();
+            vec_res.resize(col->get_data().size());
+
+            if (!vec_res.empty()) {
+                if (scale_arg == 0) { // round to an integer
+                    size_t scale = 1;
+                    FunctionRoundingImpl<ScaleMode::Zero>::apply(col->get_data(), scale, vec_res);
+                } else if (scale_arg > 0) { // keep scale_arg fractional digits
+                    size_t scale = int_exp10(scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Positive>::apply(col->get_data(), scale,
+                                                                     vec_res);
+                } else { // round to a multiple of 10^-scale_arg
+                    size_t scale = int_exp10(-scale_arg);
+                    FunctionRoundingImpl<ScaleMode::Negative>::apply(col->get_data(), scale,
+                                                                     vec_res);
+                }
+            }
+
+            return col_res;
+        } else if constexpr (IsDecimalNumber<T>) { // result keeps the input column's scale
+            const auto* const decimal_col = check_and_get_column<ColumnDecimal<T>>(col_general);
+            const auto& vec_src = decimal_col->get_data();
+
+            auto col_res = ColumnDecimal<T>::create(vec_src.size(), decimal_col->get_scale());
+            auto& vec_res = col_res->get_data();
+
+            if (!vec_res.empty()) {
+                FunctionRoundingImpl<ScaleMode::Negative>::apply(
+                        decimal_col->get_data(), decimal_col->get_scale(), vec_res, scale_arg);
+            }
+
+            return col_res;
+        } else {
+            __builtin_unreachable();
+            return nullptr;
+        }
+    }
+};
+
+template <typename Impl, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
+class FunctionRounding : public IFunction { // backs round/floor/ceil/truncate(x[, d])
+public:
+    static constexpr auto name = Impl::name;
+    static FunctionPtr create() { return std::make_shared<FunctionRounding>(); }
+
+    String get_name() const override { return name; }
+
+    bool is_variadic() const override { return true; } // x, or x and d
+    size_t get_number_of_arguments() const override { return 0; }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return Impl::get_variadic_argument_types();
+    }
+
+    /// Result type is the first argument's type (decimals keep precision and scale).
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        if ((arguments.empty()) || (arguments.size() > 2)) { // LOG(FATAL) aborts the process
+            LOG(FATAL) << "Number of arguments for function " + get_name() +
+                                  " doesn't match: should be 1 or 2. ";
+        }
+
+        return arguments[0];
+    }
+
+    static Status get_scale_arg(const ColumnWithTypeAndName& arguments, Int16* scale) {
+        const IColumn& scale_column = *arguments.column;
+        if (!is_column_const(scale_column)) {
+            return Status::InvalidArgument("2nd argument for function {} should be constant", name);
+        }
+
+        Field scale_field = assert_cast<const ColumnConst&>(scale_column).get_field();
+
+        Int64 scale64 = scale_field.get<Int64>();
+        if (scale64 > std::numeric_limits<Int16>::max() ||
+            scale64 < std::numeric_limits<Int16>::min()) {
+            return Status::InvalidArgument("Scale argument for function {} is too large: {}", name,
+                                           scale64);
+        }
+
+        *scale = scale64; // within Int16 range after the check above
+        return Status::OK();
+    }
+
+    bool use_default_implementation_for_constants() const override { return true; }
+    ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t /*input_rows_count*/) override {
+        const ColumnWithTypeAndName& column = block.get_by_position(arguments[0]);
+        Int16 scale_arg = 0;
+        if (arguments.size() == 2) { // optional second argument; defaults to 0
+            RETURN_IF_ERROR(get_scale_arg(block.get_by_position(arguments[1]), &scale_arg));
+        }
+
+        ColumnPtr res;
+        auto call = [&](const auto& types) -> bool { // per-type body for the 1st argument
+            using Types = std::decay_t<decltype(types)>;
+            using DataType = typename Types::LeftType;
+
+            if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) {
+                using FieldType = typename DataType::FieldType;
+                res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply(
+                        column.column.get(), scale_arg);
+                return true;
+            }
+            return false;
+        };
+
+#if !defined(__SSE4_1__)
+        /// Without SSE4.1 "nearbyint" is used, so FE_TONEAREST must be active for banker's rounding.
+        /// It is the default mode, but we set it explicitly just in case.
+
+        if constexpr (rounding_mode == RoundingMode::Round) {
+            if (0 != fesetround(FE_TONEAREST)) {
+                return Status::InvalidArgument("Cannot set floating point rounding mode");
+            }
+        }
+#endif
+
+        if (!call_on_index_and_data_type<void>(column.type->get_type_id(), call)) {
+            return Status::InvalidArgument("Invalid argument type {} for function {}",
+                                           column.type->get_name(), name);
+        }
+
+        block.replace_by_position(result, std::move(res));
+        return Status::OK();
+    }
+};
+
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/test/vec/function/function_math_test.cpp b/be/test/vec/function/function_math_test.cpp
index a93fb02603..88a4378d70 100644
--- a/be/test/vec/function/function_math_test.cpp
+++ b/be/test/vec/function/function_math_test.cpp
@@ -213,28 +213,14 @@ TEST(MathFunctionTest, pow_test) {
     check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
 }
 
-TEST(MathFunctionTest, truncate_test) {
-    std::string func_name = "truncate"; // truncate(x,y)
-
-    InputTypeSet input_types = {TypeIndex::Float64, TypeIndex::Float64};
-
-    DataSet data_set = {{{123.4567, 3.0}, 123.456}, {{-123.4567, 3.0}, -123.456},
-                        {{123.4567, 0.0}, 123.0},   {{-123.4567, 0.0}, -123.0},
-                        {{123.4567, -2.0}, 100.0},  {{-123.4567, -2.0}, -100.0},
-                        {{-123.4567, -3.0}, 0.0}};
-
-    check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
-}
-
 TEST(MathFunctionTest, ceil_test) {
     std::string func_name = "ceil";
 
     InputTypeSet input_types = {TypeIndex::Float64};
 
-    DataSet data_set = {
-            {{2.3}, (int64_t)3}, {{2.8}, (int64_t)3}, {{-2.3}, (int64_t)-2}, {{2.8}, (int64_t)3.0}};
+    DataSet data_set = {{{2.3}, 3.0}, {{2.8}, 3.0}, {{-2.3}, -2.0}, {{-2.8}, -2.0}};
 
-    check_function<DataTypeInt64, true>(func_name, input_types, data_set);
+    check_function<DataTypeFloat64, true>(func_name, input_types, data_set); // DOUBLE result
 }
 
 TEST(MathFunctionTest, floor_test) {
@@ -242,10 +228,9 @@ TEST(MathFunctionTest, floor_test) {
 
     InputTypeSet input_types = {TypeIndex::Float64};
 
-    DataSet data_set = {
-            {{2.3}, (int64_t)2}, {{2.8}, (int64_t)2}, {{-2.3}, (int64_t)-3}, {{-2.8}, (int64_t)-3}};
+    DataSet data_set = {{{2.3}, 2.0}, {{2.8}, 2.0}, {{-2.3}, -3.0}, {{-2.8}, -3.0}};
 
-    check_function<DataTypeInt64, true>(func_name, input_types, data_set);
+    check_function<DataTypeFloat64, true>(func_name, input_types, data_set); // DOUBLE result
 }
 
 TEST(MathFunctionTest, degrees_test) {
@@ -377,16 +362,8 @@ TEST(MathFunctionTest, round_test) {
     {
         InputTypeSet input_types = {TypeIndex::Float64};
 
-        DataSet data_set = {{{30.1}, (int64_t)30}, {{90.6}, (int64_t)91}, {{Null()}, Null()},
-                            {{0.0}, (int64_t)0},   {{-1.1}, (int64_t)-1}, {{-60.7}, (int64_t)-61}};
-
-        check_function<DataTypeInt64, true>(func_name, input_types, data_set);
-    }
-    {
-        InputTypeSet input_types = {TypeIndex::Float64, TypeIndex::Int32};
-
-        DataSet data_set = {{{3.1415926, 2}, 3.14}, {{3.1415926, 3}, 3.142}, {{Null(), -2}, Null()},
-                            {{193.0, -2}, 200.0},   {{193.0, -1}, 190.0},    {{193.0, -3}, 0.0}};
+        DataSet data_set = {{{30.1}, 30.0}, {{90.6}, 91.0}, {{Null()}, Null()},
+                            {{0.0}, 0.0},   {{-1.1}, -1.0}, {{-60.7}, -61.0}};
 
         check_function<DataTypeFloat64, true>(func_name, input_types, data_set);
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 02e7b6a979..937aeb7437 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -72,46 +72,88 @@ public class FunctionCallExpr extends Expr {
             new ImmutableSortedSet.Builder(String.CASE_INSENSITIVE_ORDER)
                     .add("stddev").add("stddev_val").add("stddev_samp").add("stddev_pop")
                     .add("variance").add("variance_pop").add("variance_pop").add("var_samp").add("var_pop").build();
-    public static final Map<String, java.util.function.Function<Type[], Type>> DECIMAL_INFER_RULE;
-    public static final java.util.function.Function<Type[], Type> DEFAULT_DECIMAL_INFER_RULE;
+    public static final Map<String, java.util.function.BiFunction<ArrayList<Expr>, Type, Type>> PRECISION_INFER_RULE;
+    public static final java.util.function.BiFunction<ArrayList<Expr>, Type, Type> DEFAULT_PRECISION_INFER_RULE;
 
     static {
-        java.util.function.Function<Type[], Type> sumRule = (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length > 0);
-            if (type[0].isDecimalV3()) {
+        java.util.function.BiFunction<ArrayList<Expr>, Type, Type> sumRule = (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() > 0);
+            if (children.get(0).getType().isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
-                        ((ScalarType) type[0]).getScalarScale());
+                        ((ScalarType) children.get(0).getType()).getScalarScale());
             } else {
-                return type[0];
+                return returnType;
             }
         };
-        DEFAULT_DECIMAL_INFER_RULE = (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length > 0);
-            return type[0];
+        DEFAULT_PRECISION_INFER_RULE = (children, returnType) -> {
+            if (children != null && children.size() > 0
+                    && children.get(0).getType().isDecimalV3() && returnType.isDecimalV3()) {
+                return children.get(0).getType();
+            } else if (children != null && children.size() > 0 && children.get(0).getType().isDatetimeV2()
+                    && returnType.isDatetimeV2()) {
+                return children.get(0).getType();
+            } else {
+                return returnType;
+            }
+        };
+        java.util.function.BiFunction<ArrayList<Expr>, Type, Type> roundRule = (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() > 0);
+            if (children.size() == 1 && children.get(0).getType().isDecimalV3()) {
+                return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(), 0);
+            } else if (children.size() == 2) {
+                Preconditions.checkArgument(children.get(1) instanceof IntLiteral
+                                || (children.get(1) instanceof CastExpr
+                                && children.get(1).getChild(0) instanceof IntLiteral),
+                        "2nd argument of function round/floor/ceil/truncate must be literal");
+                if (children.get(1) instanceof CastExpr && children.get(1).getChild(0) instanceof IntLiteral) {
+                    children.get(1).getChild(0).setType(children.get(1).getType());
+                    children.set(1, children.get(1).getChild(0));
+                } else {
+                    children.get(1).setType(Type.INT);
+                }
+                return ScalarType.createDecimalV3Type(children.get(0).getType().getPrecision(),
+                        ((ScalarType) children.get(0).getType()).decimalScale());
+            } else {
+                return returnType;
+            }
         };
-        DECIMAL_INFER_RULE = new HashMap<>();
-        DECIMAL_INFER_RULE.put("sum", sumRule);
-        DECIMAL_INFER_RULE.put("multi_distinct_sum", sumRule);
-        DECIMAL_INFER_RULE.put("avg", (com.google.common.base.Function<Type[], Type>) type -> {
+        PRECISION_INFER_RULE = new HashMap<>();
+        PRECISION_INFER_RULE.put("sum", sumRule);
+        PRECISION_INFER_RULE.put("multi_distinct_sum", sumRule);
+        PRECISION_INFER_RULE.put("avg", (children, returnType) -> {
             // TODO: how to set scale?
-            Preconditions.checkArgument(type != null && type.length > 0);
-            if (type[0].isDecimalV3()) {
+            Preconditions.checkArgument(children != null && children.size() > 0);
+            if (children.get(0).getType().isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(ScalarType.MAX_DECIMAL128_PRECISION,
-                        ((ScalarType) type[0]).getScalarScale());
+                        ((ScalarType) children.get(0).getType()).getScalarScale());
             } else {
-                return type[0];
+                return returnType;
             }
         });
-        DECIMAL_INFER_RULE.put("if", (com.google.common.base.Function<Type[], Type>) type -> {
-            Preconditions.checkArgument(type != null && type.length == 3);
-            if (type[1].isDecimalV3() && type[2].isDecimalV3()) {
+        PRECISION_INFER_RULE.put("if", (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() == 3);
+            if (children.get(1).getType().isDecimalV3() && children.get(2).getType().isDecimalV3()) {
                 return ScalarType.createDecimalV3Type(
-                        Math.max(((ScalarType) type[1]).decimalPrecision(), ((ScalarType) type[2]).decimalPrecision()),
-                        Math.max(((ScalarType) type[1]).decimalScale(), ((ScalarType) type[2]).decimalScale()));
+                        Math.max(((ScalarType) children.get(1).getType()).decimalPrecision(),
+                                ((ScalarType) children.get(2).getType()).decimalPrecision()),
+                        Math.max(((ScalarType) children.get(1).getType()).decimalScale(),
+                                ((ScalarType) children.get(2).getType()).decimalScale()));
+            } else if (children.get(1).getType().isDatetimeV2() && children.get(2).getType().isDatetimeV2()) {
+                return ((ScalarType) children.get(1).getType()).decimalScale()
+                        > ((ScalarType) children.get(2).getType()).decimalScale()
+                        ? children.get(1).getType() : children.get(2).getType();
             } else {
-                return type[0];
+                return returnType;
             }
         });
+
+        PRECISION_INFER_RULE.put("round", roundRule);
+        PRECISION_INFER_RULE.put("ceil", roundRule);
+        PRECISION_INFER_RULE.put("floor", roundRule);
+        PRECISION_INFER_RULE.put("dround", roundRule);
+        PRECISION_INFER_RULE.put("dceil", roundRule);
+        PRECISION_INFER_RULE.put("dfloor", roundRule);
+        PRECISION_INFER_RULE.put("truncate", roundRule);
     }
 
     public static final ImmutableSet<String> TIME_FUNCTIONS_WITH_PRECISION =
@@ -1068,11 +1110,6 @@ public class FunctionCallExpr extends Expr {
             }
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                     Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
-            if (fn != null && fn.getArgs()[2].isDatetime() && childTypes[2].isDatetimeV2()) {
-                fn.setArgType(childTypes[2], 2);
-            } else if (fn != null && fn.getArgs()[2].isDatetime() && childTypes[2].isDateV2()) {
-                fn.setArgType(ScalarType.DATETIMEV2, 2);
-            }
             if (fn != null && childTypes[2].isDate()) {
                 // cast date to datetime
                 uncheckedCastChild(ScalarType.DATETIME, 2);
@@ -1126,18 +1163,6 @@ public class FunctionCallExpr extends Expr {
             }
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                 Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
-            if (fn != null && fn.getArgs()[1].isDatetime() && childTypes[1].isDatetimeV2()) {
-                fn.setArgType(childTypes[1], 1);
-            } else if (fn != null && fn.getArgs()[1].isDatetime() && childTypes[1].isDateV2()) {
-                fn.setArgType(ScalarType.DATETIMEV2, 1);
-            }
-            if (fn != null && childTypes[1].isDate()) {
-                // cast date to datetime
-                uncheckedCastChild(ScalarType.DATETIME, 1);
-            } else if (fn != null && childTypes[1].isDateV2()) {
-                // cast date to datetime
-                uncheckedCastChild(ScalarType.DATETIMEV2, 1);
-            }
         } else if (fnName.getFunction().equalsIgnoreCase("if")) {
             Type[] childTypes = collectChildReturnTypes();
             Type assignmentCompatibleType = ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
@@ -1216,8 +1241,6 @@ public class FunctionCallExpr extends Expr {
             fn.setReturnType(new ArrayType(getChild(0).type));
         }
 
-        applyAutoTypeConversionForDatetimeV2();
-
         if (fnName.getFunction().equalsIgnoreCase("from_unixtime")
                 || fnName.getFunction().equalsIgnoreCase("date_format")) {
             // if has only one child, it has default time format: yyyy-MM-dd HH:mm:ss.SSSSSS
@@ -1329,6 +1352,13 @@ public class FunctionCallExpr extends Expr {
             } else {
                 this.type = ScalarType.getDefaultDateType(Type.DATETIME);
             }
+        } else if (TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase())
+                && fn.getReturnType().isDatetimeV2()) {
+            if (children.size() == 1 && children.get(0) instanceof IntLiteral) {
+                this.type = ScalarType.createDatetimeV2Type((int) ((IntLiteral) children.get(0)).getLongValue());
+            } else if (children.size() == 1) {
+                this.type = ScalarType.createDatetimeV2Type(6);
+            }
         } else {
             this.type = fn.getReturnType();
         }
@@ -1370,39 +1400,16 @@ public class FunctionCallExpr extends Expr {
             fn.setReturnType(Type.MAX_DECIMALV2_TYPE);
         }
 
-        if (this.type.isDecimalV3()) {
+        if (this.type.isDecimalV3() || (this.type.isDatetimeV2()
+                && !TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase()))) {
             // TODO(gabriel): If type exceeds max precision of DECIMALV3, we should change it to a double function
-            this.type = DECIMAL_INFER_RULE.getOrDefault(fnName.getFunction(), DEFAULT_DECIMAL_INFER_RULE)
-                    .apply(collectChildReturnTypes());
+            this.type = PRECISION_INFER_RULE.getOrDefault(fnName.getFunction(), DEFAULT_PRECISION_INFER_RULE)
+                    .apply(children, this.type);
         }
         // rewrite return type if is nested type function
         analyzeNestedFunction();
     }
 
-    private void applyAutoTypeConversionForDatetimeV2() {
-        // Rule1: Now we treat datetimev2 with different precisions as different types and we only register functions
-        // for datetimev2(0). So we must apply an automatic type conversion from datetimev2(0) to the real type.
-        if (fn.getArgs().length == children.size() && fn.getArgs().length > 0) {
-            if (fn.getArgs()[0].isDatetimeV2() && children.get(0).getType().isDatetimeV2()) {
-                fn.setArgType(children.get(0).getType(), 0);
-                if (fn.getReturnType().isDatetimeV2()) {
-                    fn.setReturnType(children.get(0).getType());
-                }
-            }
-        }
-
-        // Rule2: For functions in TIME_FUNCTIONS_WITH_PRECISION, we can't figure out which function should be use when
-        // searching in FunctionSet. So we adjust the return type by hand here.
-        if (TIME_FUNCTIONS_WITH_PRECISION.contains(fnName.getFunction().toLowerCase())
-                && fn != null && fn.getReturnType().isDatetimeV2()) {
-            if (children.size() == 1 && children.get(0) instanceof IntLiteral) {
-                fn.setReturnType(ScalarType.createDatetimeV2Type((int) ((IntLiteral) children.get(0)).getLongValue()));
-            } else if (children.size() == 1) {
-                fn.setReturnType(ScalarType.createDatetimeV2Type(6));
-            }
-        }
-    }
-
     // if return type is nested type, need to be determined the sub-element type
     private void analyzeNestedFunction() {
         // array
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
index ac9ebec41e..f1a608a9b8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java
@@ -183,10 +183,6 @@ public class Function implements Writable {
         this.retType = type;
     }
 
-    public void setArgType(Type type, int i) {
-        argTypes[i] = type;
-    }
-
     public Type[] getArgs() {
         return argTypes;
     }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/ConstantExpressTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/ConstantExpressTest.java
index 6e4b66bda8..793e5af2f2 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/ConstantExpressTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/ConstantExpressTest.java
@@ -166,13 +166,6 @@ public class ConstantExpressTest {
                 "0.1");
     }
 
-    @Test
-    public void testMath() throws Exception {
-        testConstantExpressResult(
-                "select floor(2.3);",
-                "2");
-    }
-
     @Test
     public void testPredicate() throws Exception {
         testConstantExpressResult(
diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py
index d34b0ba34b..78a6d2803e 100755
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -1825,15 +1825,69 @@ visible_functions = [
     [['atan'], 'DOUBLE', ['DOUBLE'],
             '_ZN5doris13MathFunctions4atanEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
 
-    [['ceil', 'ceiling', 'dceil'], 'BIGINT', ['DOUBLE'],
+    [['ceil', 'ceiling', 'dceil'], 'DOUBLE', ['DOUBLE'],
             '_ZN5doris13MathFunctions4ceilEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
-    [['floor', 'dfloor'], 'BIGINT', ['DOUBLE'],
+    [['floor', 'dfloor'], 'DOUBLE', ['DOUBLE'],
             '_ZN5doris13MathFunctions5floorEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
-    [['round', 'dround'], 'BIGINT', ['DOUBLE'],
+    [['round', 'dround'], 'DOUBLE', ['DOUBLE'],
+            '_ZN5doris13MathFunctions5roundEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['ceil', 'ceiling', 'dceil'], 'DECIMAL32', ['DECIMAL32'],
+            '_ZN5doris13MathFunctions4ceilEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL32', ['DECIMAL32'],
+            '_ZN5doris13MathFunctions5floorEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL32', ['DECIMAL32'],
+            '_ZN5doris13MathFunctions5roundEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['ceil', 'ceiling', 'dceil'], 'DECIMAL64', ['DECIMAL64'],
+            '_ZN5doris13MathFunctions4ceilEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL64', ['DECIMAL64'],
+            '_ZN5doris13MathFunctions5floorEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL64', ['DECIMAL64'],
+            '_ZN5doris13MathFunctions5roundEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['ceil', 'ceiling', 'dceil'], 'DECIMAL128', ['DECIMAL128'],
+            '_ZN5doris13MathFunctions4ceilEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL128', ['DECIMAL128'],
+            '_ZN5doris13MathFunctions5floorEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL128', ['DECIMAL128'],
             '_ZN5doris13MathFunctions5roundEPN9doris_udf15FunctionContextERKNS1_9DoubleValE', '', '', 'vec', ''],
     [['round', 'dround'], 'DOUBLE', ['DOUBLE', 'INT'],
             '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
             '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL32', ['DECIMAL32', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL64', ['DECIMAL64', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['round', 'dround'], 'DECIMAL128', ['DECIMAL128', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL32', ['DECIMAL32', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL64', ['DECIMAL64', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['floor', 'dfloor'], 'DECIMAL128', ['DECIMAL128', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['ceil', 'dceil'], 'DECIMAL32', ['DECIMAL32', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['ceil', 'dceil'], 'DECIMAL64', ['DECIMAL64', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['ceil', 'dceil'], 'DECIMAL128', ['DECIMAL128', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['truncate'], 'DECIMAL32', ['DECIMAL32', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['truncate'], 'DECIMAL64', ['DECIMAL64', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
+    [['truncate'], 'DECIMAL128', ['DECIMAL128', 'INT'],
+            '_ZN5doris13MathFunctions11round_up_toEPN9doris_udf'
+            '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
     [['truncate'], 'DOUBLE', ['DOUBLE', 'INT'],
             '_ZN5doris13MathFunctions8truncateEPN9doris_udf'
             '15FunctionContextERKNS1_9DoubleValERKNS1_6IntValE', '', '', 'vec', ''],
diff --git a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
index 7723af2e13..1c672c02f4 100644
--- a/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
+++ b/regression-test/data/query_p0/sql_functions/math_functions/test_round.out
@@ -1,13 +1,19 @@
 -- This file is automatically generated. You should know what you did if you want to edit this
 -- !select --
-10
+10.0
 
 -- !select --
 10.12
 
 -- !select --
-10
+16.030	16.03000	16.03000
 
 -- !select --
-10.12
+16.020	16.02000	16.02000
+
+-- !select --
+16.030	16.03000	16.03000
+
+-- !select --
+16.020	16.02000	16.02000
 
diff --git a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
index f40f3b857d..a79417b885 100644
--- a/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
+++ b/regression-test/suites/query_p0/sql_functions/math_functions/test_round.groovy
@@ -16,15 +16,25 @@
 // under the License.
 
 suite("test_round") {
-    // non vectorized
-    sql """ set enable_vectorized_engine = false """
-
-    qt_select "SELECT round(10.12345)"
-    qt_select "SELECT round(10.12345, 2)"
-
     // vectorized
     sql """ set enable_vectorized_engine = true """
 
     qt_select "SELECT round(10.12345)"
     qt_select "SELECT round(10.12345, 2)"
+
+    def tableName = "test_round"
+    sql """DROP TABLE IF EXISTS `${tableName}`"""
+    sql """ CREATE TABLE `${tableName}` (
+        `col1` DECIMALV3(6,3) COMMENT "",
+        `col2` DECIMALV3(16,5) COMMENT "",
+        `col3` DECIMALV3(32,5) COMMENT "")
+        DUPLICATE KEY(`col1`) DISTRIBUTED BY HASH(`col1`)
+        PROPERTIES ( "replication_num" = "1" ); """
+
+    sql """ insert into `${tableName}` values(16.025, 16.025, 16.025); """
+    qt_select """ SELECT round(col1, 2), round(col2, 2), round(col3, 2) FROM `${tableName}`; """
+    qt_select """ SELECT floor(col1, 2), floor(col2, 2), floor(col3, 2) FROM `${tableName}`; """
+    qt_select """ SELECT ceil(col1, 2), ceil(col2, 2), ceil(col3, 2) FROM `${tableName}`; """
+    qt_select """ SELECT truncate(col1, 2), truncate(col2, 2), truncate(col3, 2) FROM `${tableName}`; """
+    sql """ DROP TABLE IF EXISTS ${tableName} """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org