You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/07/30 10:39:02 UTC

[doris] branch master updated: [refactor][vectorized] refactor first/last value agg functions (#10661)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f30e563a7 [refactor][vectorized] refactor first/last value agg functions (#10661)
1f30e563a7 is described below

commit 1f30e563a708cc5c1800f8d469a43f2f011592f2
Author: zhangstar333 <87...@users.noreply.github.com>
AuthorDate: Sat Jul 30 18:38:56 2022 +0800

    [refactor][vectorized] refactor first/last value agg functions (#10661)
    
    * refactor first and last
    [refactor][vectorized] refactor first/last value agg functions
    
    * add some change
    
    * remove first/last about always nullable
    
    * remove always nullable and register it
    
    * refactor value remove bool null flag
    
    * refactor win first last to ptr and pos
---
 be/src/vec/CMakeLists.txt                          |   1 -
 .../aggregate_function_reader.cpp                  |  22 +-
 .../aggregate_function_reader.h                    |   2 +-
 .../aggregate_function_reader_first_last.h         | 287 +++++++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   5 +-
 .../aggregate_function_window.cpp                  |  89 ++++--
 .../aggregate_function_window.h                    | 312 ++++-----------------
 be/src/vec/exec/join/vhash_join_node.cpp           |   8 -
 be/src/vec/utils/template_helpers.hpp              |  22 +-
 9 files changed, 434 insertions(+), 314 deletions(-)

diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 6be1a9c797..4e30835e4b 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -41,7 +41,6 @@ set(VEC_FILES
   aggregate_functions/aggregate_function_group_concat.cpp
   aggregate_functions/aggregate_function_percentile_approx.cpp
   aggregate_functions/aggregate_function_simple_factory.cpp
-  aggregate_functions/aggregate_function_java_udaf.h
   aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
   columns/column.cpp
   columns/column_array.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
index 3e27f30b85..8a3bea08bd 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
@@ -44,23 +44,19 @@ void register_aggregate_function_replace_reader_load(AggregateFunctionSimpleFact
         factory.register_function(name + suffix, creator, nullable);
     };
 
-    register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<false, true>,
-                      false);
-    register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true, true>,
-                      true);
-    register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false, false>,
-                      false);
-    register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<true, false>,
-                      true);
+    register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true>, false);
+    register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true>, true);
+    register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false>, false);
+    register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false>, true);
 
     register_function("replace_if_not_null", AGG_READER_SUFFIX,
-                      create_aggregate_function_first_non_null_value<false, true>, false);
+                      create_aggregate_function_first_non_null_value<true>, false);
     register_function("replace_if_not_null", AGG_READER_SUFFIX,
-                      create_aggregate_function_first_non_null_value<true, true>, true);
+                      create_aggregate_function_first_non_null_value<true>, true);
     register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
-                      create_aggregate_function_last_non_null_value<false, false>, false);
+                      create_aggregate_function_last_non_null_value<false>, false);
     register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
-                      create_aggregate_function_last_non_null_value<true, false>, true);
+                      create_aggregate_function_last_non_null_value<false>, true);
 }
 
-} // namespace doris::vectorized
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.h b/be/src/vec/aggregate_functions/aggregate_function_reader.h
index 86fea6f079..626c06571b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.h
@@ -20,9 +20,9 @@
 #include "vec/aggregate_functions/aggregate_function_bitmap.h"
 #include "vec/aggregate_functions/aggregate_function_hll_union_agg.h"
 #include "vec/aggregate_functions/aggregate_function_min_max.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
 #include "vec/aggregate_functions/aggregate_function_sum.h"
-#include "vec/aggregate_functions/aggregate_function_window.h"
 
 namespace doris::vectorized {
 
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
new file mode 100644
index 0000000000..4b7c1e0c98
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
@@ -0,0 +1,287 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "factory_helpers.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+#include "vec/io/io_helper.h"
+#include "vec/utils/template_helpers.hpp"
+
+namespace doris::vectorized {
+
+template <typename ColVecType, bool arg_is_nullable>
+struct Value {
+public:
+    bool is_null() const {
+        if (_ptr == nullptr) {
+            return true;
+        }
+        if constexpr (arg_is_nullable) {
+            return assert_cast<const ColumnNullable*>(_ptr)->is_null_at(_offset);
+        }
+        return false;
+    }
+
+    StringRef get_value() const {
+        if constexpr (arg_is_nullable) {
+            auto* col = assert_cast<const ColumnNullable*>(_ptr);
+            return assert_cast<const ColVecType&>(col->get_nested_column()).get_data_at(_offset);
+        } else {
+            return assert_cast<const ColVecType*>(_ptr)->get_data_at(_offset);
+        }
+    }
+
+    void set_value(const IColumn* column, size_t row) {
+        _ptr = column;
+        _offset = row;
+    }
+
+    void reset() {
+        _ptr = nullptr;
+        _offset = 0;
+    }
+
+protected:
+    const IColumn* _ptr = nullptr;
+    size_t _offset = 0;
+};
+
+template <typename ColVecType, bool arg_is_nullable>
+struct CopiedValue : public Value<ColVecType, arg_is_nullable> {
+public:
+    StringRef get_value() const { return _copied_value; }
+
+    bool is_null() const { return this->_ptr == nullptr; }
+
+    void set_value(const IColumn* column, size_t row) {
+        // The value at `row` may be null; in that case reset() below puts _ptr back to nullptr.
+        // Otherwise _ptr holds a dummy non-null address that is not meant to be dereferenced:
+        // the data lives in _copied_value, and is_null() only compares _ptr with nullptr.
+        this->_ptr = (IColumn*)0x00000001;
+        if constexpr (arg_is_nullable) {
+            auto* col = assert_cast<const ColumnNullable*>(column);
+            if (col->is_null_at(row)) {
+                this->reset();
+                return;
+            } else {
+                _copied_value = assert_cast<const ColVecType&>(col->get_nested_column())
+                                        .get_data_at(row)
+                                        .to_string();
+            }
+        } else {
+            _copied_value = assert_cast<const ColVecType*>(column)->get_data_at(row).to_string();
+        }
+    }
+
+private:
+    std::string _copied_value;
+};
+
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
+struct ReaderFirstAndLastData {
+public:
+    using StoreType = std::conditional_t<is_copy, CopiedValue<ColVecType, arg_is_nullable>,
+                                         Value<ColVecType, arg_is_nullable>>;
+    static constexpr bool nullable = arg_is_nullable;
+
+    void reset() {
+        _data_value.reset();
+        _has_value = false;
+    }
+
+    void insert_result_into(IColumn& to) const {
+        if constexpr (result_is_nullable) {
+            if (_data_value.is_null()) { //_ptr == nullptr || null data at row
+                auto& col = assert_cast<ColumnNullable&>(to);
+                col.insert_default();
+            } else {
+                auto& col = assert_cast<ColumnNullable&>(to);
+                //get_value will never get null value
+                const StringRef& value = _data_value.get_value();
+                col.get_null_map_data().push_back(0);
+                assert_cast<ColVecType&>(col.get_nested_column())
+                        .insert_data(value.data, value.size);
+            }
+        } else {
+            const StringRef& value = _data_value.get_value();
+            assert_cast<ColVecType&>(to).insert_data(value.data, value.size);
+        }
+    }
+
+    // This does not check whether columns[0] is null at `pos`;
+    // callers that need to skip null rows must check before calling.
+    void set_value(const IColumn** columns, size_t pos) {
+        _data_value.set_value(columns[0], pos);
+        _has_value = true;
+    }
+
+    bool has_set_value() { return _has_value; }
+
+protected:
+    StoreType _data_value;
+    bool _has_value = false;
+};
+
+template <typename Data>
+struct ReaderFunctionFirstData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        this->set_value(columns, row);
+    }
+    static const char* name() { return "first_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionFirstNonNullData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        if constexpr (Data::nullable) {
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            if (nullable_column->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+    static const char* name() { return "first_non_null_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionLastData : Data {
+    void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
+    static const char* name() { return "last_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionLastNonNullData : Data {
+    void add(int64_t row, const IColumn** columns) {
+        if constexpr (Data::nullable) {
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            if (nullable_column->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+
+    static const char* name() { return "last_non_null_value"; }
+};
+
+template <typename Data>
+class ReaderFunctionData final
+        : public IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>> {
+public:
+    ReaderFunctionData(const DataTypes& argument_types)
+            : IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>>(argument_types, {}),
+              _argument_type(argument_types[0]) {}
+
+    String get_name() const override { return Data::name(); }
+
+    DataTypePtr get_return_type() const override { return _argument_type; }
+
+    void insert_result_into(ConstAggregateDataPtr place, IColumn& to) const override {
+        this->data(place).insert_result_into(to);
+    }
+
+    void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
+             Arena* arena) const override {
+        this->data(place).add(row_num, columns);
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+                                int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
+                                Arena* arena) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support add_range_single_place";
+    }
+    void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support merge";
+    }
+    void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support serialize";
+    }
+    void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {
+        LOG(FATAL) << "ReaderFunctionData do not support deserialize";
+    }
+
+private:
+    DataTypePtr _argument_type;
+};
+
+template <template <typename> class AggregateFunctionTemplate, template <typename> class Impl,
+          bool result_is_nullable, bool arg_is_nullable, bool is_copy = false>
+static IAggregateFunction* create_function_single_value(const String& name,
+                                                        const DataTypes& argument_types,
+                                                        const Array& parameters) {
+    auto type = remove_nullable(argument_types[0]);
+    WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE)                                       \
+    if (which.idx == TypeIndex::TYPE)                                     \
+        return new AggregateFunctionTemplate<Impl<ReaderFirstAndLastData< \
+                COLUMN_TYPE, result_is_nullable, arg_is_nullable, is_copy>>>(argument_types);
+    TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+    LOG(FATAL) << "with unknowed type, failed in  create_aggregate_function_" << name
+               << " and type is: " << argument_types[0]->get_name();
+    return nullptr;
+}
+
+#define CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA)            \
+    template <bool is_copy>                                                                       \
+    AggregateFunctionPtr CREATE_FUNCTION_NAME(const std::string& name,                            \
+                                              const DataTypes& argument_types,                    \
+                                              const Array& parameters, bool result_is_nullable) { \
+        const bool arg_is_nullable = argument_types[0]->is_nullable();                            \
+        AggregateFunctionPtr res = nullptr;                                                       \
+        std::visit(                                                                               \
+                [&](auto result_is_nullable, auto arg_is_nullable) {                              \
+                    res = AggregateFunctionPtr(                                                   \
+                            create_function_single_value<ReaderFunctionData, FUNCTION_DATA,       \
+                                                         result_is_nullable, arg_is_nullable,     \
+                                                         is_copy>(name, argument_types,           \
+                                                                  parameters));                   \
+                },                                                                                \
+                make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable));       \
+        if (!res) {                                                                               \
+            LOG(WARNING) << " failed in  create_aggregate_function_" << name                      \
+                         << " and type is: " << argument_types[0]->get_name();                    \
+        }                                                                                         \
+        return res;                                                                               \
+    }
+
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first, ReaderFunctionFirstData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first_non_null_value,
+                                          ReaderFunctionFirstNonNullData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last, ReaderFunctionLastData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last_non_null_value,
+                                          ReaderFunctionLastNonNullData);
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 73779f8ffa..79d21985bc 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -39,7 +39,8 @@ void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& f
 void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory);
-void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_window_lead_lag_first_last(
+        AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
@@ -81,7 +82,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
 
         register_aggregate_function_stddev_variance_samp(instance);
         register_aggregate_function_replace_reader_load(instance);
-        register_aggregate_function_window_lead_lag(instance);
+        register_aggregate_function_window_lead_lag_first_last(instance);
         register_aggregate_function_HLL_union_agg(instance);
         register_aggregate_function_percentile_approx(instance);
     });
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
index 1a342d805a..02b283ab2d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
@@ -22,7 +22,7 @@
 
 #include "common/logging.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/utils/template_helpers.hpp"
 
 namespace doris::vectorized {
 
@@ -62,25 +62,59 @@ AggregateFunctionPtr create_aggregate_function_ntile(const std::string& name,
     return std::make_shared<WindowFunctionNTile>(argument_types, parameters);
 }
 
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lag(const std::string& name,
-                                                   const DataTypes& argument_types,
-                                                   const Array& parameters,
-                                                   const bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLagData, is_nullable>(
-                    name, argument_types, parameters));
+template <template <typename> class AggregateFunctionTemplate,
+          template <typename ColVecType, bool, bool> class Data, template <typename> class Impl,
+          bool result_is_nullable, bool arg_is_nullable>
+static IAggregateFunction* create_function_lead_lag_first_last(const String& name,
+                                                               const DataTypes& argument_types,
+                                                               const Array& parameters) {
+    auto type = remove_nullable(argument_types[0]);
+    WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE)           \
+    if (which.idx == TypeIndex::TYPE)         \
+        return new AggregateFunctionTemplate< \
+                Impl<Data<COLUMN_TYPE, result_is_nullable, arg_is_nullable>>>(argument_types);
+    TYPE_TO_BASIC_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+    LOG(FATAL) << "with unknowed type, failed in  create_aggregate_function_" << name
+               << " and type is: " << argument_types[0]->get_name();
+    return nullptr;
 }
 
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lead(const std::string& name,
-                                                    const DataTypes& argument_types,
-                                                    const Array& parameters,
-                                                    const bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLeadData, is_nullable>(
-                    name, argument_types, parameters));
-}
+#define CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA,             \
+                                                  FUNCTION_IMPL)                                   \
+    AggregateFunctionPtr CREATE_FUNCTION_NAME(                                                     \
+            const std::string& name, const DataTypes& argument_types, const Array& parameters,     \
+            const bool result_is_nullable) {                                                       \
+        const bool arg_is_nullable = argument_types[0]->is_nullable();                             \
+        AggregateFunctionPtr res = nullptr;                                                        \
+                                                                                                   \
+        std::visit(                                                                                \
+                [&](auto result_is_nullable, auto arg_is_nullable) {                               \
+                    res = AggregateFunctionPtr(                                                    \
+                            create_function_lead_lag_first_last<WindowFunctionData, FUNCTION_DATA, \
+                                                                FUNCTION_IMPL, result_is_nullable, \
+                                                                arg_is_nullable>(                  \
+                                    name, argument_types, parameters));                            \
+                },                                                                                 \
+                make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable));        \
+        if (!res) {                                                                                \
+            LOG(WARNING) << " failed in  create_aggregate_function_" << name                       \
+                         << " and type is: " << argument_types[0]->get_name();                     \
+        }                                                                                          \
+        return res;                                                                                \
+    }
+
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lag, LeadLagData,
+                                          WindowFunctionLagImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lead, LeadLagData,
+                                          WindowFunctionLeadImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_first, FirstLastData,
+                                          WindowFunctionFirstImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_last, FirstLastData,
+                                          WindowFunctionLastImpl);
 
 void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory) {
     factory.register_function("dense_rank", create_aggregate_function_dense_rank);
@@ -89,15 +123,16 @@ void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
     factory.register_function("ntile", create_aggregate_function_ntile);
 }
 
-void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory) {
-    factory.register_function("lead", create_aggregate_function_lead<false>);
-    factory.register_function("lead", create_aggregate_function_lead<true>, true);
-    factory.register_function("lag", create_aggregate_function_lag<false>);
-    factory.register_function("lag", create_aggregate_function_lag<true>, true);
-    factory.register_function("first_value", create_aggregate_function_first<false, false>);
-    factory.register_function("first_value", create_aggregate_function_first<true, false>, true);
-    factory.register_function("last_value", create_aggregate_function_last<false, false>);
-    factory.register_function("last_value", create_aggregate_function_last<true, false>, true);
+void register_aggregate_function_window_lead_lag_first_last(
+        AggregateFunctionSimpleFactory& factory) {
+    factory.register_function("lead", create_aggregate_function_window_lead);
+    factory.register_function("lead", create_aggregate_function_window_lead, true);
+    factory.register_function("lag", create_aggregate_function_window_lag);
+    factory.register_function("lag", create_aggregate_function_window_lag, true);
+    factory.register_function("first_value", create_aggregate_function_window_first);
+    factory.register_function("first_value", create_aggregate_function_window_first, true);
+    factory.register_function("last_value", create_aggregate_function_window_last);
+    factory.register_function("last_value", create_aggregate_function_window_last, true);
 }
 
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 6e0dba239e..c0d37f8e80 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -22,6 +22,7 @@
 
 #include "factory_helpers.h"
 #include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
 #include "vec/aggregate_functions/helpers.h"
 #include "vec/columns/column_vector.h"
 #include "vec/data_types/data_type_decimal.h"
@@ -207,60 +208,36 @@ public:
     void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {}
 };
 
-struct Value {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct FirstLastData
+        : public ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, false> {
 public:
-    bool is_null() const { return _is_null; }
-    void set_null(bool is_null) { _is_null = is_null; }
-    StringRef get_value() const { return _ptr->get_data_at(_offset); }
-
-    void set_value(const IColumn* column, size_t row) {
-        _ptr = column;
-        _offset = row;
-    }
-    void reset() {
-        _is_null = false;
-        _ptr = nullptr;
-        _offset = 0;
-    }
-
-protected:
-    const IColumn* _ptr = nullptr;
-    size_t _offset = 0;
-    bool _is_null;
+    void set_is_null() { this->_data_value.reset(); }
 };
 
-struct CopiedValue : public Value {
+template <typename ColVecType, bool arg_is_nullable>
+struct BaseValue : public Value<ColVecType, arg_is_nullable> {
 public:
-    StringRef get_value() const { return _copied_value; }
-
-    void set_value(const IColumn* column, size_t row) {
-        _copied_value = column->get_data_at(row).to_string();
-    }
-
-private:
-    std::string _copied_value;
+    bool is_null() const { return this->_ptr == nullptr; }
+    // _ptr may point to the first argument or to the third (default) argument, so it is hard
+    // to cast it to a concrete column type; call the virtual get_data_at() instead.
+    StringRef get_value() const { return this->_ptr->get_data_at(this->_offset); }
 };
 
-template <typename T, bool result_is_nullable, bool is_string, typename StoreType = Value>
-struct LeadAndLagData {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct LeadLagData {
 public:
-    bool has_init() const { return _is_init; }
-
-    static constexpr bool nullable = result_is_nullable;
-
-    void set_null_if_need() {
-        if (!_has_value) {
-            this->set_is_null();
-        }
-    }
-
     void reset() {
         _data_value.reset();
         _default_value.reset();
-        _is_init = false;
-        _has_value = false;
+        _is_inited = false;
     }
 
+    bool default_is_null() { return _default_value.is_null(); }
+
+    // After this, _data_value's _ptr points to the default column (the third argument).
+    void set_value_from_default() { this->_data_value = _default_value; }
+
     void insert_result_into(IColumn& to) const {
         if constexpr (result_is_nullable) {
             if (_data_value.is_null()) {
@@ -277,53 +254,47 @@ public:
         }
     }
 
+    void set_is_null() { this->_data_value.reset(); }
+
     void set_value(const IColumn** columns, size_t pos) {
-        if (columns[0]->is_nullable() &&
-            assert_cast<const ColumnNullable*>(columns[0])->is_null_at(pos)) {
-            _data_value.set_null(true);
-        } else {
-            _data_value.set_value(columns[0], pos);
-            _data_value.set_null(false);
+        if constexpr (arg_is_nullable) {
+            if (assert_cast<const ColumnNullable*>(columns[0])->is_null_at(pos)) {
+                // ptr == nullptr means the value is null
+                _data_value.reset();
+                return;
+            }
         }
-        _has_value = true;
+        // here ptr points to the first argument column, which may be nullable or not nullable
+        _data_value.set_value(columns[0], pos);
     }
 
-    bool defualt_is_null() { return _default_value.is_null(); }
-
-    void set_is_null() { _data_value.set_null(true); }
-
-    void set_value_from_default() { _data_value = _default_value; }
-
-    bool has_set_value() { return _has_value; }
-
     void check_default(const IColumn* column) {
-        if (!has_init()) {
+        if (!_is_inited) {
             if (is_column_nullable(*column)) {
                 const auto* nullable_column = assert_cast<const ColumnNullable*>(column);
                 if (nullable_column->is_null_at(0)) {
-                    _default_value.set_null(true);
+                    _default_value.reset();
                 }
             } else {
                 _default_value.set_value(column, 0);
             }
-            _is_init = true;
+            _is_inited = true;
         }
     }
 
 private:
-    StoreType _data_value;
-    StoreType _default_value;
-    bool _has_value = false;
-    bool _is_init = false;
+    BaseValue<ColVecType, arg_is_nullable> _data_value;
+    BaseValue<ColVecType, arg_is_nullable> _default_value;
+    bool _is_inited = false;
 };
 
 template <typename Data>
-struct WindowFunctionLeadData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t partition_end, size_t frame_start,
-                                size_t frame_end, const IColumn** columns) {
+struct WindowFunctionLeadImpl : Data {
+    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+                                int64_t frame_end, const IColumn** columns) {
         this->check_default(columns[2]);
         if (frame_end > partition_end) { //output default value, win end is under partition
-            if (this->defualt_is_null()) {
+            if (this->default_is_null()) {
                 this->set_is_null();
             } else {
                 this->set_value_from_default();
@@ -332,19 +303,17 @@ struct WindowFunctionLeadData : Data {
         }
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) {
-        LOG(FATAL) << "WindowFunctionLeadData do not support add";
-    }
+
     static const char* name() { return "lead"; }
 };
 
 template <typename Data>
-struct WindowFunctionLagData : Data {
+struct WindowFunctionLagImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         this->check_default(columns[2]);
         if (partition_start >= frame_end) { //[unbound preceding(0), offset preceding(-123)]
-            if (this->defualt_is_null()) {  // win start is beyond partition
+            if (this->default_is_null()) {  // win start is beyond partition
                 this->set_is_null();
             } else {
                 this->set_value_from_default();
@@ -353,14 +322,15 @@ struct WindowFunctionLagData : Data {
         }
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) {
-        LOG(FATAL) << "WindowFunctionLagData do not support add";
-    }
+
     static const char* name() { return "lag"; }
 };
 
+// TODO: first_value and last_value may core dump in some corner cases.
+// A simple fix would be to make them always nullable and insert a null value; the registration in the cpp file may then need to change.
+// But there may be a better way to handle it.
 template <typename Data>
-struct WindowFunctionFirstData : Data {
+struct WindowFunctionFirstImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         if (this->has_set_value()) {
@@ -374,61 +344,12 @@ struct WindowFunctionFirstData : Data {
         frame_start = std::max<int64_t>(frame_start, partition_start);
         this->set_value(columns, frame_start);
     }
-    void add(int64_t row, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        this->set_value(columns, row);
-    }
-    static const char* name() { return "first_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionFirstNonNullData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
-                                int64_t frame_end, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        if (frame_start < frame_end &&
-            frame_end <= partition_start) { //rewrite last_value when under partition
-            this->set_is_null();            //so no need more judge
-            return;
-        }
-        frame_start = std::max<int64_t>(frame_start, partition_start);
-        frame_end = std::min<int64_t>(frame_end, partition_end);
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
-            for (int i = frame_start; i < frame_end; i++) {
-                if (!nullable_column->is_null_at(i)) {
-                    this->set_value(columns, i);
-                    return;
-                }
-            }
-        } else {
-            this->set_value(columns, frame_start);
-        }
-    }
 
-    void add(int64_t row, const IColumn** columns) {
-        if (this->has_set_value()) {
-            return;
-        }
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
-            if (nullable_column->is_null_at(row)) {
-                return;
-            }
-        }
-        this->set_value(columns, row);
-    }
-    static const char* name() { return "first_non_null_value"; }
+    static const char* name() { return "first_value"; }
 };
 
 template <typename Data>
-struct WindowFunctionLastData : Data {
+struct WindowFunctionLastImpl : Data {
     void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
                                 int64_t frame_end, const IColumn** columns) {
         if ((frame_start < frame_end) &&
@@ -440,48 +361,8 @@ struct WindowFunctionLastData : Data {
         frame_end = std::min<int64_t>(frame_end, partition_end);
         this->set_value(columns, frame_end - 1);
     }
-    void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
-    static const char* name() { return "last_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionLastNonNullData : Data {
-    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
-                                int64_t frame_end, const IColumn** columns) {
-        if ((frame_start < frame_end) &&
-            ((frame_end <= partition_start) ||
-             (frame_start >= partition_end))) { //beyond or under partition, set null
-            this->set_is_null();
-            return;
-        }
-        frame_start = std::max<int64_t>(frame_start, partition_start);
-        frame_end = std::min<int64_t>(frame_end, partition_end);
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
-            for (int i = frame_end - 1; i >= frame_start; i--) {
-                if (!nullable_column->is_null_at(i)) {
-                    this->set_value(columns, i);
-                    return;
-                }
-            }
-        } else {
-            this->set_value(columns, frame_end - 1);
-        }
-    }
-
-    void add(int64_t row, const IColumn** columns) {
-        if constexpr (Data::nullable) {
-            this->set_null_if_need();
-            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
-            if (nullable_column->is_null_at(row)) {
-                return;
-            }
-        }
-        this->set_value(columns, row);
-    }
 
-    static const char* name() { return "last_non_null_value"; }
+    static const char* name() { return "last_value"; }
 };
 
 template <typename Data>
@@ -493,6 +374,7 @@ public:
               _argument_type(argument_types[0]) {}
 
     String get_name() const override { return Data::name(); }
+
     DataTypePtr get_return_type() const override { return _argument_type; }
 
     void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
@@ -510,104 +392,20 @@ public:
 
     void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
              Arena* arena) const override {
-        this->data(place).add(row_num, columns);
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support add";
     }
     void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) const override {
-        LOG(FATAL) << "WindowFunctionData do not support merge";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support merge";
     }
     void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const override {
-        LOG(FATAL) << "WindowFunctionData do not support serialize";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support serialize";
     }
     void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {
-        LOG(FATAL) << "WindowFunctionData do not support deserialize";
+        LOG(FATAL) << "WindowFunctionLeadLagData do not support deserialize";
     }
 
 private:
     DataTypePtr _argument_type;
 };
 
-template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
-          bool result_is_nullable, bool is_copy = false>
-static IAggregateFunction* create_function_single_value(const String& name,
-                                                        const DataTypes& argument_types,
-                                                        const Array& parameters) {
-    using StoreType = std::conditional_t<is_copy, CopiedValue, Value>;
-
-    assert_arity_at_most<3>(name, argument_types);
-
-    auto type = remove_nullable(argument_types[0]);
-    WhichDataType which(*type);
-
-#define DISPATCH(TYPE)                        \
-    if (which.idx == TypeIndex::TYPE)         \
-        return new AggregateFunctionTemplate< \
-                Data<LeadAndLagData<TYPE, result_is_nullable, false, StoreType>>>(argument_types);
-    FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-
-    if (which.is_decimal()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<Int128, result_is_nullable, false, StoreType>>>(argument_types);
-    }
-    if (which.is_date_or_datetime()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<Int64, result_is_nullable, false, StoreType>>>(argument_types);
-    }
-    if (which.is_date_v2()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<UInt32, result_is_nullable, false, StoreType>>>(argument_types);
-    }
-    if (which.is_date_time_v2()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<UInt64, result_is_nullable, false, StoreType>>>(argument_types);
-    }
-    if (which.is_string_or_fixed_string()) {
-        return new AggregateFunctionTemplate<
-                Data<LeadAndLagData<StringRef, result_is_nullable, true, StoreType>>>(
-                argument_types);
-    }
-    DCHECK(false) << "with unknowed type, failed in  create_aggregate_function_" << name;
-    return nullptr;
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first(const std::string& name,
-                                                     const DataTypes& argument_types,
-                                                     const Array& parameters,
-                                                     bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionFirstData, is_nullable,
-                                         is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first_non_null_value(const std::string& name,
-                                                                    const DataTypes& argument_types,
-                                                                    const Array& parameters,
-                                                                    bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionFirstNonNullData,
-                                         is_nullable, is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last(const std::string& name,
-                                                    const DataTypes& argument_types,
-                                                    const Array& parameters,
-                                                    bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLastData, is_nullable,
-                                         is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last_non_null_value(const std::string& name,
-                                                                   const DataTypes& argument_types,
-                                                                   const Array& parameters,
-                                                                   bool result_is_nullable) {
-    return AggregateFunctionPtr(
-            create_function_single_value<WindowFunctionData, WindowFunctionLastNonNullData,
-                                         is_nullable, is_copy>(name, argument_types, parameters));
-}
-
 } // namespace doris::vectorized
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index c88fbbb683..6cbdcfa53f 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -31,14 +31,6 @@
 
 namespace doris::vectorized {
 
-std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
-    if (condition) {
-        return std::true_type {};
-    } else {
-        return std::false_type {};
-    }
-}
-
 using ProfileCounter = RuntimeProfile::Counter;
 template <class HashTableContext>
 struct ProcessHashTableBuild {
diff --git a/be/src/vec/utils/template_helpers.hpp b/be/src/vec/utils/template_helpers.hpp
index ebf822513b..187ec7accc 100644
--- a/be/src/vec/utils/template_helpers.hpp
+++ b/be/src/vec/utils/template_helpers.hpp
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <limits>
+#include <variant>
 
 #include "http/http_status.h"
 #include "vec/aggregate_functions/aggregate_function.h"
@@ -53,11 +54,14 @@
     M(BitMap, ColumnBitmap)            \
     M(HLL, ColumnHLL)
 
-#define TYPE_TO_COLUMN_TYPE(M)     \
-    NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
-    DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
-    STRING_TYPE_TO_COLUMN_TYPE(M)  \
-    TIME_TYPE_TO_COLUMN_TYPE(M)    \
+#define TYPE_TO_BASIC_COLUMN_TYPE(M) \
+    NUMERIC_TYPE_TO_COLUMN_TYPE(M)   \
+    DECIMAL_TYPE_TO_COLUMN_TYPE(M)   \
+    STRING_TYPE_TO_COLUMN_TYPE(M)    \
+    TIME_TYPE_TO_COLUMN_TYPE(M)
+
+#define TYPE_TO_COLUMN_TYPE(M)   \
+    TYPE_TO_BASIC_COLUMN_TYPE(M) \
     COMPLEX_TYPE_TO_COLUMN_TYPE(M)
 
 namespace doris::vectorized {
@@ -150,4 +154,12 @@ template <template <bool, bool, bool> typename Reducer>
 using constexpr_3_bool_match =
         constexpr_3_loop_match<bool, false, true, Reducer, constexpr_2_bool_match>;
 
+std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
+    if (condition) {
+        return std::true_type {};
+    } else {
+        return std::false_type {};
+    }
+}
+
 } // namespace  doris::vectorized


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org