You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/07/30 10:39:02 UTC
[doris] branch master updated: [refactor][vectorized] refactor first/last value agg functions (#10661)
This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 1f30e563a7 [refactor][vectorized] refactor first/last value agg functions (#10661)
1f30e563a7 is described below
commit 1f30e563a708cc5c1800f8d469a43f2f011592f2
Author: zhangstar333 <87...@users.noreply.github.com>
AuthorDate: Sat Jul 30 18:38:56 2022 +0800
[refactor][vectorized] refactor first/last value agg functions (#10661)
* refactor first and last
[refactor][vectorized] refactor first/last value agg functions
* add some change
* remove first/last about always nullable
* remove always nullable and register it
* refactor value remove bool null flag
* refactor win first last to ptr and pos
---
be/src/vec/CMakeLists.txt | 1 -
.../aggregate_function_reader.cpp | 22 +-
.../aggregate_function_reader.h | 2 +-
.../aggregate_function_reader_first_last.h | 287 +++++++++++++++++++
.../aggregate_function_simple_factory.cpp | 5 +-
.../aggregate_function_window.cpp | 89 ++++--
.../aggregate_function_window.h | 312 ++++-----------------
be/src/vec/exec/join/vhash_join_node.cpp | 8 -
be/src/vec/utils/template_helpers.hpp | 22 +-
9 files changed, 434 insertions(+), 314 deletions(-)
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 6be1a9c797..4e30835e4b 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -41,7 +41,6 @@ set(VEC_FILES
aggregate_functions/aggregate_function_group_concat.cpp
aggregate_functions/aggregate_function_percentile_approx.cpp
aggregate_functions/aggregate_function_simple_factory.cpp
- aggregate_functions/aggregate_function_java_udaf.h
aggregate_functions/aggregate_function_orthogonal_bitmap.cpp
columns/column.cpp
columns/column_array.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
index 3e27f30b85..8a3bea08bd 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
@@ -44,23 +44,19 @@ void register_aggregate_function_replace_reader_load(AggregateFunctionSimpleFact
factory.register_function(name + suffix, creator, nullable);
};
- register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<false, true>,
- false);
- register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true, true>,
- true);
- register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false, false>,
- false);
- register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<true, false>,
- true);
+ register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true>, false);
+ register_function("replace", AGG_READER_SUFFIX, create_aggregate_function_first<true>, true);
+ register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false>, false);
+ register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<false>, true);
register_function("replace_if_not_null", AGG_READER_SUFFIX,
- create_aggregate_function_first_non_null_value<false, true>, false);
+ create_aggregate_function_first_non_null_value<true>, false);
register_function("replace_if_not_null", AGG_READER_SUFFIX,
- create_aggregate_function_first_non_null_value<true, true>, true);
+ create_aggregate_function_first_non_null_value<true>, true);
register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
- create_aggregate_function_last_non_null_value<false, false>, false);
+ create_aggregate_function_last_non_null_value<false>, false);
register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
- create_aggregate_function_last_non_null_value<true, false>, true);
+ create_aggregate_function_last_non_null_value<false>, true);
}
-} // namespace doris::vectorized
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.h b/be/src/vec/aggregate_functions/aggregate_function_reader.h
index 86fea6f079..626c06571b 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.h
@@ -20,9 +20,9 @@
#include "vec/aggregate_functions/aggregate_function_bitmap.h"
#include "vec/aggregate_functions/aggregate_function_hll_union_agg.h"
#include "vec/aggregate_functions/aggregate_function_min_max.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_sum.h"
-#include "vec/aggregate_functions/aggregate_function_window.h"
namespace doris::vectorized {
diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
new file mode 100644
index 0000000000..4b7c1e0c98
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader_first_last.h
@@ -0,0 +1,287 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "factory_helpers.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/helpers.h"
+#include "vec/columns/column_nullable.h"
+#include "vec/columns/column_vector.h"
+#include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+#include "vec/io/io_helper.h"
+#include "vec/utils/template_helpers.hpp"
+
+namespace doris::vectorized {
+
+template <typename ColVecType, bool arg_is_nullable>
+struct Value {
+public:
+ bool is_null() const {
+ if (_ptr == nullptr) {
+ return true;
+ }
+ if constexpr (arg_is_nullable) {
+ return assert_cast<const ColumnNullable*>(_ptr)->is_null_at(_offset);
+ }
+ return false;
+ }
+
+ StringRef get_value() const {
+ if constexpr (arg_is_nullable) {
+ auto* col = assert_cast<const ColumnNullable*>(_ptr);
+ return assert_cast<const ColVecType&>(col->get_nested_column()).get_data_at(_offset);
+ } else {
+ return assert_cast<const ColVecType*>(_ptr)->get_data_at(_offset);
+ }
+ }
+
+ void set_value(const IColumn* column, size_t row) {
+ _ptr = column;
+ _offset = row;
+ }
+
+ void reset() {
+ _ptr = nullptr;
+ _offset = 0;
+ }
+
+protected:
+ const IColumn* _ptr = nullptr;
+ size_t _offset = 0;
+};
+
+template <typename ColVecType, bool arg_is_nullable>
+struct CopiedValue : public Value<ColVecType, arg_is_nullable> {
+public:
+ StringRef get_value() const { return _copied_value; }
+
+ bool is_null() const { return this->_ptr == nullptr; }
+
+ void set_value(const IColumn* column, size_t row) {
+ // _ptr is used here only as a null / non-null flag: it is first set to a meaningless non-null address,
+ // and reset() puts it back to nullptr when the value at `row` is null. Callers check is_null() first,
+ // so CopiedValue only ever compares this address with nullptr and never dereferences it.
+ this->_ptr = (IColumn*)0x00000001;
+ if constexpr (arg_is_nullable) {
+ auto* col = assert_cast<const ColumnNullable*>(column);
+ if (col->is_null_at(row)) {
+ this->reset();
+ return;
+ } else {
+ _copied_value = assert_cast<const ColVecType&>(col->get_nested_column())
+ .get_data_at(row)
+ .to_string();
+ }
+ } else {
+ _copied_value = assert_cast<const ColVecType*>(column)->get_data_at(row).to_string();
+ }
+ }
+
+private:
+ std::string _copied_value;
+};
+
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable, bool is_copy>
+struct ReaderFirstAndLastData {
+public:
+ using StoreType = std::conditional_t<is_copy, CopiedValue<ColVecType, arg_is_nullable>,
+ Value<ColVecType, arg_is_nullable>>;
+ static constexpr bool nullable = arg_is_nullable;
+
+ void reset() {
+ _data_value.reset();
+ _has_value = false;
+ }
+
+ void insert_result_into(IColumn& to) const {
+ if constexpr (result_is_nullable) {
+ if (_data_value.is_null()) { //_ptr == nullptr || null data at row
+ auto& col = assert_cast<ColumnNullable&>(to);
+ col.insert_default();
+ } else {
+ auto& col = assert_cast<ColumnNullable&>(to);
+ // is_null() was false above, so get_value() does not return a null value here
+ const StringRef& value = _data_value.get_value();
+ col.get_null_map_data().push_back(0);
+ assert_cast<ColVecType&>(col.get_nested_column())
+ .insert_data(value.data, value.size);
+ }
+ } else {
+ const StringRef& value = _data_value.get_value();
+ assert_cast<ColVecType&>(to).insert_data(value.data, value.size);
+ }
+ }
+
+ // This does not check whether columns[0] is null at `pos`;
+ // callers that must skip null rows have to do that check themselves before calling.
+ void set_value(const IColumn** columns, size_t pos) {
+ _data_value.set_value(columns[0], pos);
+ _has_value = true;
+ }
+
+ bool has_set_value() { return _has_value; }
+
+protected:
+ StoreType _data_value;
+ bool _has_value = false;
+};
+
+template <typename Data>
+struct ReaderFunctionFirstData : Data {
+ void add(int64_t row, const IColumn** columns) {
+ if (this->has_set_value()) {
+ return;
+ }
+ this->set_value(columns, row);
+ }
+ static const char* name() { return "first_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionFirstNonNullData : Data {
+ void add(int64_t row, const IColumn** columns) {
+ if (this->has_set_value()) {
+ return;
+ }
+ if constexpr (Data::nullable) {
+ const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+ if (nullable_column->is_null_at(row)) {
+ return;
+ }
+ }
+ this->set_value(columns, row);
+ }
+ static const char* name() { return "first_non_null_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionLastData : Data {
+ void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
+ static const char* name() { return "last_value"; }
+};
+
+template <typename Data>
+struct ReaderFunctionLastNonNullData : Data {
+ void add(int64_t row, const IColumn** columns) {
+ if constexpr (Data::nullable) {
+ const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+ if (nullable_column->is_null_at(row)) {
+ return;
+ }
+ }
+ this->set_value(columns, row);
+ }
+
+ static const char* name() { return "last_non_null_value"; }
+};
+
+template <typename Data>
+class ReaderFunctionData final
+ : public IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>> {
+public:
+ ReaderFunctionData(const DataTypes& argument_types)
+ : IAggregateFunctionDataHelper<Data, ReaderFunctionData<Data>>(argument_types, {}),
+ _argument_type(argument_types[0]) {}
+
+ String get_name() const override { return Data::name(); }
+
+ DataTypePtr get_return_type() const override { return _argument_type; }
+
+ void insert_result_into(ConstAggregateDataPtr place, IColumn& to) const override {
+ this->data(place).insert_result_into(to);
+ }
+
+ void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
+ Arena* arena) const override {
+ this->data(place).add(row_num, columns);
+ }
+
+ void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+ void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+ int64_t frame_end, AggregateDataPtr place, const IColumn** columns,
+ Arena* arena) const override {
+ LOG(FATAL) << "ReaderFunctionData do not support add_range_single_place";
+ }
+ void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) const override {
+ LOG(FATAL) << "ReaderFunctionData do not support merge";
+ }
+ void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const override {
+ LOG(FATAL) << "ReaderFunctionData do not support serialize";
+ }
+ void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {
+ LOG(FATAL) << "ReaderFunctionData do not support deserialize";
+ }
+
+private:
+ DataTypePtr _argument_type;
+};
+
+template <template <typename> class AggregateFunctionTemplate, template <typename> class Impl,
+ bool result_is_nullable, bool arg_is_nullable, bool is_copy = false>
+static IAggregateFunction* create_function_single_value(const String& name,
+ const DataTypes& argument_types,
+ const Array& parameters) {
+ auto type = remove_nullable(argument_types[0]);
+ WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new AggregateFunctionTemplate<Impl<ReaderFirstAndLastData< \
+ COLUMN_TYPE, result_is_nullable, arg_is_nullable, is_copy>>>(argument_types);
+ TYPE_TO_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+ LOG(FATAL) << "with unknowed type, failed in create_aggregate_function_" << name
+ << " and type is: " << argument_types[0]->get_name();
+ return nullptr;
+}
+
+#define CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA) \
+ template <bool is_copy> \
+ AggregateFunctionPtr CREATE_FUNCTION_NAME(const std::string& name, \
+ const DataTypes& argument_types, \
+ const Array& parameters, bool result_is_nullable) { \
+ const bool arg_is_nullable = argument_types[0]->is_nullable(); \
+ AggregateFunctionPtr res = nullptr; \
+ std::visit( \
+ [&](auto result_is_nullable, auto arg_is_nullable) { \
+ res = AggregateFunctionPtr( \
+ create_function_single_value<ReaderFunctionData, FUNCTION_DATA, \
+ result_is_nullable, arg_is_nullable, \
+ is_copy>(name, argument_types, \
+ parameters)); \
+ }, \
+ make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable)); \
+ if (!res) { \
+ LOG(WARNING) << " failed in create_aggregate_function_" << name \
+ << " and type is: " << argument_types[0]->get_name(); \
+ } \
+ return res; \
+ }
+
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first, ReaderFunctionFirstData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_first_non_null_value,
+ ReaderFunctionFirstNonNullData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last, ReaderFunctionLastData);
+CREATE_READER_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_last_non_null_value,
+ ReaderFunctionLastNonNullData);
+} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index 73779f8ffa..79d21985bc 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -39,7 +39,8 @@ void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& f
void register_aggregate_function_uniq(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bitmap(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory);
-void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_window_lead_lag_first_last(
+ AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_stddev_variance_pop(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_stddev_variance_samp(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory);
@@ -81,7 +82,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_stddev_variance_samp(instance);
register_aggregate_function_replace_reader_load(instance);
- register_aggregate_function_window_lead_lag(instance);
+ register_aggregate_function_window_lead_lag_first_last(instance);
register_aggregate_function_HLL_union_agg(instance);
register_aggregate_function_percentile_approx(instance);
});
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
index 1a342d805a..02b283ab2d 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp
@@ -22,7 +22,7 @@
#include "common/logging.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/utils/template_helpers.hpp"
namespace doris::vectorized {
@@ -62,25 +62,59 @@ AggregateFunctionPtr create_aggregate_function_ntile(const std::string& name,
return std::make_shared<WindowFunctionNTile>(argument_types, parameters);
}
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lag(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- const bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionLagData, is_nullable>(
- name, argument_types, parameters));
+template <template <typename> class AggregateFunctionTemplate,
+ template <typename ColVecType, bool, bool> class Data, template <typename> class Impl,
+ bool result_is_nullable, bool arg_is_nullable>
+static IAggregateFunction* create_function_lead_lag_first_last(const String& name,
+ const DataTypes& argument_types,
+ const Array& parameters) {
+ auto type = remove_nullable(argument_types[0]);
+ WhichDataType which(*type);
+
+#define DISPATCH(TYPE, COLUMN_TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new AggregateFunctionTemplate< \
+ Impl<Data<COLUMN_TYPE, result_is_nullable, arg_is_nullable>>>(argument_types);
+ TYPE_TO_BASIC_COLUMN_TYPE(DISPATCH)
+#undef DISPATCH
+
+ LOG(FATAL) << "with unknowed type, failed in create_aggregate_function_" << name
+ << " and type is: " << argument_types[0]->get_name();
+ return nullptr;
}
-template <bool is_nullable>
-AggregateFunctionPtr create_aggregate_function_lead(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- const bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionLeadData, is_nullable>(
- name, argument_types, parameters));
-}
+#define CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(CREATE_FUNCTION_NAME, FUNCTION_DATA, \
+ FUNCTION_IMPL) \
+ AggregateFunctionPtr CREATE_FUNCTION_NAME( \
+ const std::string& name, const DataTypes& argument_types, const Array& parameters, \
+ const bool result_is_nullable) { \
+ const bool arg_is_nullable = argument_types[0]->is_nullable(); \
+ AggregateFunctionPtr res = nullptr; \
+ \
+ std::visit( \
+ [&](auto result_is_nullable, auto arg_is_nullable) { \
+ res = AggregateFunctionPtr( \
+ create_function_lead_lag_first_last<WindowFunctionData, FUNCTION_DATA, \
+ FUNCTION_IMPL, result_is_nullable, \
+ arg_is_nullable>( \
+ name, argument_types, parameters)); \
+ }, \
+ make_bool_variant(result_is_nullable), make_bool_variant(arg_is_nullable)); \
+ if (!res) { \
+ LOG(WARNING) << " failed in create_aggregate_function_" << name \
+ << " and type is: " << argument_types[0]->get_name(); \
+ } \
+ return res; \
+ }
+
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lag, LeadLagData,
+ WindowFunctionLagImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_lead, LeadLagData,
+ WindowFunctionLeadImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_first, FirstLastData,
+ WindowFunctionFirstImpl);
+CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_last, FirstLastData,
+ WindowFunctionLastImpl);
void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory) {
factory.register_function("dense_rank", create_aggregate_function_dense_rank);
@@ -89,15 +123,16 @@ void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& fac
factory.register_function("ntile", create_aggregate_function_ntile);
}
-void register_aggregate_function_window_lead_lag(AggregateFunctionSimpleFactory& factory) {
- factory.register_function("lead", create_aggregate_function_lead<false>);
- factory.register_function("lead", create_aggregate_function_lead<true>, true);
- factory.register_function("lag", create_aggregate_function_lag<false>);
- factory.register_function("lag", create_aggregate_function_lag<true>, true);
- factory.register_function("first_value", create_aggregate_function_first<false, false>);
- factory.register_function("first_value", create_aggregate_function_first<true, false>, true);
- factory.register_function("last_value", create_aggregate_function_last<false, false>);
- factory.register_function("last_value", create_aggregate_function_last<true, false>, true);
+void register_aggregate_function_window_lead_lag_first_last(
+ AggregateFunctionSimpleFactory& factory) {
+ factory.register_function("lead", create_aggregate_function_window_lead);
+ factory.register_function("lead", create_aggregate_function_window_lead, true);
+ factory.register_function("lag", create_aggregate_function_window_lag);
+ factory.register_function("lag", create_aggregate_function_window_lag, true);
+ factory.register_function("first_value", create_aggregate_function_window_first);
+ factory.register_function("first_value", create_aggregate_function_window_first, true);
+ factory.register_function("last_value", create_aggregate_function_window_last);
+ factory.register_function("last_value", create_aggregate_function_window_last, true);
}
} // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 6e0dba239e..c0d37f8e80 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -22,6 +22,7 @@
#include "factory_helpers.h"
#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_reader_first_last.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/columns/column_vector.h"
#include "vec/data_types/data_type_decimal.h"
@@ -207,60 +208,36 @@ public:
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {}
};
-struct Value {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct FirstLastData
+ : public ReaderFirstAndLastData<ColVecType, result_is_nullable, arg_is_nullable, false> {
public:
- bool is_null() const { return _is_null; }
- void set_null(bool is_null) { _is_null = is_null; }
- StringRef get_value() const { return _ptr->get_data_at(_offset); }
-
- void set_value(const IColumn* column, size_t row) {
- _ptr = column;
- _offset = row;
- }
- void reset() {
- _is_null = false;
- _ptr = nullptr;
- _offset = 0;
- }
-
-protected:
- const IColumn* _ptr = nullptr;
- size_t _offset = 0;
- bool _is_null;
+ void set_is_null() { this->_data_value.reset(); }
};
-struct CopiedValue : public Value {
+template <typename ColVecType, bool arg_is_nullable>
+struct BaseValue : public Value<ColVecType, arg_is_nullable> {
public:
- StringRef get_value() const { return _copied_value; }
-
- void set_value(const IColumn* column, size_t row) {
- _copied_value = column->get_data_at(row).to_string();
- }
-
-private:
- std::string _copied_value;
+ bool is_null() const { return this->_ptr == nullptr; }
+ // _ptr may point to either the first argument column or the third (default) argument column,
+ // so it is difficult to cast it to one concrete column type; call the virtual get_data_at() instead.
+ StringRef get_value() const { return this->_ptr->get_data_at(this->_offset); }
};
-template <typename T, bool result_is_nullable, bool is_string, typename StoreType = Value>
-struct LeadAndLagData {
+template <typename ColVecType, bool result_is_nullable, bool arg_is_nullable>
+struct LeadLagData {
public:
- bool has_init() const { return _is_init; }
-
- static constexpr bool nullable = result_is_nullable;
-
- void set_null_if_need() {
- if (!_has_value) {
- this->set_is_null();
- }
- }
-
void reset() {
_data_value.reset();
_default_value.reset();
- _is_init = false;
- _has_value = false;
+ _is_inited = false;
}
+ bool default_is_null() { return _default_value.is_null(); }
+
+ // copies the default value, so _data_value's _ptr then refers to the default column (the third argument)
+ void set_value_from_default() { this->_data_value = _default_value; }
+
void insert_result_into(IColumn& to) const {
if constexpr (result_is_nullable) {
if (_data_value.is_null()) {
@@ -277,53 +254,47 @@ public:
}
}
+ void set_is_null() { this->_data_value.reset(); }
+
void set_value(const IColumn** columns, size_t pos) {
- if (columns[0]->is_nullable() &&
- assert_cast<const ColumnNullable*>(columns[0])->is_null_at(pos)) {
- _data_value.set_null(true);
- } else {
- _data_value.set_value(columns[0], pos);
- _data_value.set_null(false);
+ if constexpr (arg_is_nullable) {
+ if (assert_cast<const ColumnNullable*>(columns[0])->is_null_at(pos)) {
+ // _ptr == nullptr marks the value as null
+ _data_value.reset();
+ return;
+ }
}
- _has_value = true;
+ // here _ptr points to the first argument column, which may be a nullable or a non-nullable column
+ _data_value.set_value(columns[0], pos);
}
- bool defualt_is_null() { return _default_value.is_null(); }
-
- void set_is_null() { _data_value.set_null(true); }
-
- void set_value_from_default() { _data_value = _default_value; }
-
- bool has_set_value() { return _has_value; }
-
void check_default(const IColumn* column) {
- if (!has_init()) {
+ if (!_is_inited) {
if (is_column_nullable(*column)) {
const auto* nullable_column = assert_cast<const ColumnNullable*>(column);
if (nullable_column->is_null_at(0)) {
- _default_value.set_null(true);
+ _default_value.reset();
}
} else {
_default_value.set_value(column, 0);
}
- _is_init = true;
+ _is_inited = true;
}
}
private:
- StoreType _data_value;
- StoreType _default_value;
- bool _has_value = false;
- bool _is_init = false;
+ BaseValue<ColVecType, arg_is_nullable> _data_value;
+ BaseValue<ColVecType, arg_is_nullable> _default_value;
+ bool _is_inited = false;
};
template <typename Data>
-struct WindowFunctionLeadData : Data {
- void add_range_single_place(int64_t partition_start, int64_t partition_end, size_t frame_start,
- size_t frame_end, const IColumn** columns) {
+struct WindowFunctionLeadImpl : Data {
+ void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+ int64_t frame_end, const IColumn** columns) {
this->check_default(columns[2]);
if (frame_end > partition_end) { //output default value, win end is under partition
- if (this->defualt_is_null()) {
+ if (this->default_is_null()) {
this->set_is_null();
} else {
this->set_value_from_default();
@@ -332,19 +303,17 @@ struct WindowFunctionLeadData : Data {
}
this->set_value(columns, frame_end - 1);
}
- void add(int64_t row, const IColumn** columns) {
- LOG(FATAL) << "WindowFunctionLeadData do not support add";
- }
+
static const char* name() { return "lead"; }
};
template <typename Data>
-struct WindowFunctionLagData : Data {
+struct WindowFunctionLagImpl : Data {
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
int64_t frame_end, const IColumn** columns) {
this->check_default(columns[2]);
if (partition_start >= frame_end) { //[unbound preceding(0), offset preceding(-123)]
- if (this->defualt_is_null()) { // win start is beyond partition
+ if (this->default_is_null()) { // win start is beyond partition
this->set_is_null();
} else {
this->set_value_from_default();
@@ -353,14 +322,15 @@ struct WindowFunctionLagData : Data {
}
this->set_value(columns, frame_end - 1);
}
- void add(int64_t row, const IColumn** columns) {
- LOG(FATAL) << "WindowFunctionLagData do not support add";
- }
+
static const char* name() { return "lag"; }
};
+// TODO: first_value and last_value can crash (core dump) in some corner cases.
+// A simple fix would be to make them always nullable and insert a null value; the registration in the .cpp file may then need to change too.
+// But there may be a better way to handle it.
template <typename Data>
-struct WindowFunctionFirstData : Data {
+struct WindowFunctionFirstImpl : Data {
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
int64_t frame_end, const IColumn** columns) {
if (this->has_set_value()) {
@@ -374,61 +344,12 @@ struct WindowFunctionFirstData : Data {
frame_start = std::max<int64_t>(frame_start, partition_start);
this->set_value(columns, frame_start);
}
- void add(int64_t row, const IColumn** columns) {
- if (this->has_set_value()) {
- return;
- }
- this->set_value(columns, row);
- }
- static const char* name() { return "first_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionFirstNonNullData : Data {
- void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
- int64_t frame_end, const IColumn** columns) {
- if (this->has_set_value()) {
- return;
- }
- if (frame_start < frame_end &&
- frame_end <= partition_start) { //rewrite last_value when under partition
- this->set_is_null(); //so no need more judge
- return;
- }
- frame_start = std::max<int64_t>(frame_start, partition_start);
- frame_end = std::min<int64_t>(frame_end, partition_end);
- if constexpr (Data::nullable) {
- this->set_null_if_need();
- const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
- for (int i = frame_start; i < frame_end; i++) {
- if (!nullable_column->is_null_at(i)) {
- this->set_value(columns, i);
- return;
- }
- }
- } else {
- this->set_value(columns, frame_start);
- }
- }
- void add(int64_t row, const IColumn** columns) {
- if (this->has_set_value()) {
- return;
- }
- if constexpr (Data::nullable) {
- this->set_null_if_need();
- const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
- if (nullable_column->is_null_at(row)) {
- return;
- }
- }
- this->set_value(columns, row);
- }
- static const char* name() { return "first_non_null_value"; }
+ static const char* name() { return "first_value"; }
};
template <typename Data>
-struct WindowFunctionLastData : Data {
+struct WindowFunctionLastImpl : Data {
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
int64_t frame_end, const IColumn** columns) {
if ((frame_start < frame_end) &&
@@ -440,48 +361,8 @@ struct WindowFunctionLastData : Data {
frame_end = std::min<int64_t>(frame_end, partition_end);
this->set_value(columns, frame_end - 1);
}
- void add(int64_t row, const IColumn** columns) { this->set_value(columns, row); }
- static const char* name() { return "last_value"; }
-};
-
-template <typename Data>
-struct WindowFunctionLastNonNullData : Data {
- void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
- int64_t frame_end, const IColumn** columns) {
- if ((frame_start < frame_end) &&
- ((frame_end <= partition_start) ||
- (frame_start >= partition_end))) { //beyond or under partition, set null
- this->set_is_null();
- return;
- }
- frame_start = std::max<int64_t>(frame_start, partition_start);
- frame_end = std::min<int64_t>(frame_end, partition_end);
- if constexpr (Data::nullable) {
- this->set_null_if_need();
- const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
- for (int i = frame_end - 1; i >= frame_start; i--) {
- if (!nullable_column->is_null_at(i)) {
- this->set_value(columns, i);
- return;
- }
- }
- } else {
- this->set_value(columns, frame_end - 1);
- }
- }
-
- void add(int64_t row, const IColumn** columns) {
- if constexpr (Data::nullable) {
- this->set_null_if_need();
- const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
- if (nullable_column->is_null_at(row)) {
- return;
- }
- }
- this->set_value(columns, row);
- }
- static const char* name() { return "last_non_null_value"; }
+ static const char* name() { return "last_value"; }
};
template <typename Data>
@@ -493,6 +374,7 @@ public:
_argument_type(argument_types[0]) {}
String get_name() const override { return Data::name(); }
+
DataTypePtr get_return_type() const override { return _argument_type; }
void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
@@ -510,104 +392,20 @@ public:
void add(AggregateDataPtr place, const IColumn** columns, size_t row_num,
Arena* arena) const override {
- this->data(place).add(row_num, columns);
+ LOG(FATAL) << "WindowFunctionLeadLagData do not support add";
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena*) const override {
- LOG(FATAL) << "WindowFunctionData do not support merge";
+ LOG(FATAL) << "WindowFunctionLeadLagData do not support merge";
}
void serialize(ConstAggregateDataPtr place, BufferWritable& buf) const override {
- LOG(FATAL) << "WindowFunctionData do not support serialize";
+ LOG(FATAL) << "WindowFunctionLeadLagData do not support serialize";
}
void deserialize(AggregateDataPtr place, BufferReadable& buf, Arena*) const override {
- LOG(FATAL) << "WindowFunctionData do not support deserialize";
+ LOG(FATAL) << "WindowFunctionLeadLagData do not support deserialize";
}
private:
DataTypePtr _argument_type;
};
-template <template <typename> class AggregateFunctionTemplate, template <typename> class Data,
- bool result_is_nullable, bool is_copy = false>
-static IAggregateFunction* create_function_single_value(const String& name,
- const DataTypes& argument_types,
- const Array& parameters) {
- using StoreType = std::conditional_t<is_copy, CopiedValue, Value>;
-
- assert_arity_at_most<3>(name, argument_types);
-
- auto type = remove_nullable(argument_types[0]);
- WhichDataType which(*type);
-
-#define DISPATCH(TYPE) \
- if (which.idx == TypeIndex::TYPE) \
- return new AggregateFunctionTemplate< \
- Data<LeadAndLagData<TYPE, result_is_nullable, false, StoreType>>>(argument_types);
- FOR_NUMERIC_TYPES(DISPATCH)
-#undef DISPATCH
-
- if (which.is_decimal()) {
- return new AggregateFunctionTemplate<
- Data<LeadAndLagData<Int128, result_is_nullable, false, StoreType>>>(argument_types);
- }
- if (which.is_date_or_datetime()) {
- return new AggregateFunctionTemplate<
- Data<LeadAndLagData<Int64, result_is_nullable, false, StoreType>>>(argument_types);
- }
- if (which.is_date_v2()) {
- return new AggregateFunctionTemplate<
- Data<LeadAndLagData<UInt32, result_is_nullable, false, StoreType>>>(argument_types);
- }
- if (which.is_date_time_v2()) {
- return new AggregateFunctionTemplate<
- Data<LeadAndLagData<UInt64, result_is_nullable, false, StoreType>>>(argument_types);
- }
- if (which.is_string_or_fixed_string()) {
- return new AggregateFunctionTemplate<
- Data<LeadAndLagData<StringRef, result_is_nullable, true, StoreType>>>(
- argument_types);
- }
- DCHECK(false) << "with unknowed type, failed in create_aggregate_function_" << name;
- return nullptr;
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionFirstData, is_nullable,
- is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_first_non_null_value(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionFirstNonNullData,
- is_nullable, is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionLastData, is_nullable,
- is_copy>(name, argument_types, parameters));
-}
-
-template <bool is_nullable, bool is_copy>
-AggregateFunctionPtr create_aggregate_function_last_non_null_value(const std::string& name,
- const DataTypes& argument_types,
- const Array& parameters,
- bool result_is_nullable) {
- return AggregateFunctionPtr(
- create_function_single_value<WindowFunctionData, WindowFunctionLastNonNullData,
- is_nullable, is_copy>(name, argument_types, parameters));
-}
-
} // namespace doris::vectorized
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index c88fbbb683..6cbdcfa53f 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -31,14 +31,6 @@
namespace doris::vectorized {
-std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
- if (condition) {
- return std::true_type {};
- } else {
- return std::false_type {};
- }
-}
-
using ProfileCounter = RuntimeProfile::Counter;
template <class HashTableContext>
struct ProcessHashTableBuild {
diff --git a/be/src/vec/utils/template_helpers.hpp b/be/src/vec/utils/template_helpers.hpp
index ebf822513b..187ec7accc 100644
--- a/be/src/vec/utils/template_helpers.hpp
+++ b/be/src/vec/utils/template_helpers.hpp
@@ -18,6 +18,7 @@
#pragma once
#include <limits>
+#include <variant>
#include "http/http_status.h"
#include "vec/aggregate_functions/aggregate_function.h"
@@ -53,11 +54,14 @@
M(BitMap, ColumnBitmap) \
M(HLL, ColumnHLL)
-#define TYPE_TO_COLUMN_TYPE(M) \
- NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
- DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
- STRING_TYPE_TO_COLUMN_TYPE(M) \
- TIME_TYPE_TO_COLUMN_TYPE(M) \
+#define TYPE_TO_BASIC_COLUMN_TYPE(M) \
+ NUMERIC_TYPE_TO_COLUMN_TYPE(M) \
+ DECIMAL_TYPE_TO_COLUMN_TYPE(M) \
+ STRING_TYPE_TO_COLUMN_TYPE(M) \
+ TIME_TYPE_TO_COLUMN_TYPE(M)
+
+#define TYPE_TO_COLUMN_TYPE(M) \
+ TYPE_TO_BASIC_COLUMN_TYPE(M) \
COMPLEX_TYPE_TO_COLUMN_TYPE(M)
namespace doris::vectorized {
@@ -150,4 +154,12 @@ template <template <bool, bool, bool> typename Reducer>
using constexpr_3_bool_match =
constexpr_3_loop_match<bool, false, true, Reducer, constexpr_2_bool_match>;
+std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
+ if (condition) {
+ return std::true_type {};
+ } else {
+ return std::false_type {};
+ }
+}
+
} // namespace doris::vectorized
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org