You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/05/07 09:17:00 UTC

[incubator-doris] branch master updated: [BUG][Vectorized] fix `replace_if_not_null` in vectorized compaction (#9376)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 49890ce9aa [BUG][Vectorized] fix `replace_if_not_null` in vectorized compaction (#9376)
49890ce9aa is described below

commit 49890ce9aa0726da17c211da94265bd8e8065c36
Author: Gabriel <ga...@gmail.com>
AuthorDate: Sat May 7 17:16:54 2022 +0800

    [BUG][Vectorized] fix `replace_if_not_null` in vectorized compaction (#9376)
---
 .../aggregate_function_reader.cpp                  |   8 +-
 .../aggregate_function_window.h                    | 122 ++++++++++++++++++++-
 2 files changed, 122 insertions(+), 8 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
index 6b49f7390e..ef3d5c2375 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_reader.cpp
@@ -53,9 +53,13 @@ void register_aggregate_function_replace_reader_load(AggregateFunctionSimpleFact
     register_function("replace", AGG_LOAD_SUFFIX, create_aggregate_function_last<true, true>, true);
 
     register_function("replace_if_not_null", AGG_READER_SUFFIX,
-                      create_aggregate_function_first<false, true>, false);
+                      create_aggregate_function_first_non_null_value<false, true>, false);
+    register_function("replace_if_not_null", AGG_READER_SUFFIX,
+                      create_aggregate_function_first_non_null_value<true, true>, true);
+    register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
+                      create_aggregate_function_last_non_null_value<false, true>, false);
     register_function("replace_if_not_null", AGG_LOAD_SUFFIX,
-                      create_aggregate_function_last<false, true>, false);
+                      create_aggregate_function_last_non_null_value<true, true>, true);
 }
 
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h b/be/src/vec/aggregate_functions/aggregate_function_window.h
index 8952ef2b44..33b73218c9 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_window.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_window.h
@@ -24,6 +24,7 @@
 #include "vec/aggregate_functions/helpers.h"
 #include "vec/columns/column_vector.h"
 #include "vec/data_types/data_type_decimal.h"
+#include "vec/data_types/data_type_nullable.h"
 #include "vec/data_types/data_type_number.h"
 #include "vec/data_types/data_type_string.h"
 #include "vec/io/io_helper.h"
@@ -187,6 +188,14 @@ struct LeadAndLagData {
 public:
     bool has_init() const { return _is_init; }
 
+    static constexpr bool nullable = is_nullable;
+
+    // Calls set_is_null() as long as no value has been stored yet (_has_value is false).
+    void set_null_if_need() {
+        if (_has_value) return;
+        this->set_is_null();
+    }
+
     void reset() {
         _data_value.reset();
         _default_value.reset();
@@ -223,7 +232,7 @@ public:
 
     void set_value(const IColumn** columns, int64_t pos) {
         if constexpr (is_nullable) {
-            const auto* nullable_column = check_and_get_column<ColumnNullable>(columns[0]);
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
             if (nullable_column && nullable_column->is_null_at(pos)) {
                 _data_value.set_null(true);
                 _has_value = true;
@@ -259,7 +268,7 @@ public:
     void check_default(const IColumn* column) {
         if (!has_init()) {
             if (is_column_nullable(*column)) {
-                const auto* nullable_column = check_and_get_column<ColumnNullable>(column);
+                const auto* nullable_column = assert_cast<const ColumnNullable*>(column);
                 if (nullable_column->is_null_at(0)) {
                     _default_value.set_null(true);
                 }
@@ -348,6 +357,50 @@ struct WindowFunctionFirstData : Data {
     static const char* name() { return "first_value"; }
 };
 
+// Stores the first non-null row it sees (nullable Data) or simply the first row (otherwise).
+template <typename Data>
+struct WindowFunctionFirstNonNullData : Data {
+    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+                                int64_t frame_end, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        if (frame_start < frame_end && frame_end <= partition_start) {
+            this->set_is_null(); // non-empty frame ends at or before partition_start
+            return;
+        }
+        frame_start = std::max<int64_t>(frame_start, partition_start); // clamp to the partition
+        frame_end = std::min<int64_t>(frame_end, partition_end);
+        if constexpr (Data::nullable) {
+            this->set_null_if_need();
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            for (int64_t i = frame_start; i < frame_end; ++i) { // int64_t: bounds not narrowed
+                if (!nullable_column->is_null_at(i)) {
+                    this->set_value(columns, i);
+                    return;
+                }
+            }
+        } else {
+            this->set_value(columns, frame_start);
+        }
+    }
+
+    // Single-row variant: skipped once has_set_value() is true or when the row is null.
+    void add(int64_t row, const IColumn** columns) {
+        if (this->has_set_value()) {
+            return;
+        }
+        if constexpr (Data::nullable) {
+            this->set_null_if_need();
+            if (assert_cast<const ColumnNullable*>(columns[0])->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+    static const char* name() { return "first_non_null_value"; }
+};
+
 template <typename Data>
 struct WindowFunctionLastData : Data {
     void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
@@ -365,6 +418,46 @@ struct WindowFunctionLastData : Data {
     static const char* name() { return "last_value"; }
 };
 
+// Stores the last non-null row it sees (nullable Data) or simply the last row (otherwise).
+template <typename Data>
+struct WindowFunctionLastNonNullData : Data {
+    void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start,
+                                int64_t frame_end, const IColumn** columns) {
+        if (frame_start < frame_end &&
+            (frame_end <= partition_start || frame_start >= partition_end)) {
+            this->set_is_null(); // non-empty frame lies entirely outside the partition
+            return;
+        }
+        frame_start = std::max<int64_t>(frame_start, partition_start); // clamp to the partition
+        frame_end = std::min<int64_t>(frame_end, partition_end);
+        if constexpr (Data::nullable) {
+            this->set_null_if_need();
+            const auto* nullable_column = assert_cast<const ColumnNullable*>(columns[0]);
+            for (int64_t i = frame_end - 1; i >= frame_start; --i) { // int64_t: bounds not narrowed
+                if (!nullable_column->is_null_at(i)) {
+                    this->set_value(columns, i);
+                    return;
+                }
+            }
+        } else {
+            this->set_value(columns, frame_end - 1);
+        }
+    }
+
+    // Single-row variant: a null row leaves the state untouched, any other row is stored.
+    void add(int64_t row, const IColumn** columns) {
+        if constexpr (Data::nullable) {
+            this->set_null_if_need();
+            if (assert_cast<const ColumnNullable*>(columns[0])->is_null_at(row)) {
+                return;
+            }
+        }
+        this->set_value(columns, row);
+    }
+
+    static const char* name() { return "last_non_null_value"; }
+};
+
 template <typename Data>
 class WindowFunctionData final
         : public IAggregateFunctionDataHelper<Data, WindowFunctionData<Data>> {
@@ -416,10 +509,7 @@ static IAggregateFunction* create_function_single_value(const String& name,
 
     assert_arity_at_most<3>(name, argument_types);
 
-    auto type = argument_types[0].get();
-    if (type->is_nullable()) {
-        type = assert_cast<const DataTypeNullable*>(type)->get_nested_type().get();
-    }
+    auto type = remove_nullable(argument_types[0]);
     WhichDataType which(*type);
 
 #define DISPATCH(TYPE)                        \
@@ -455,6 +545,16 @@ AggregateFunctionPtr create_aggregate_function_first(const std::string& name,
                                          is_copy>(name, argument_types, parameters));
 }
 
+// Builds the WindowFunctionFirstNonNullData aggregate; `result_is_nullable` is not read here.
+template <bool is_nullable, bool is_copy>
+AggregateFunctionPtr create_aggregate_function_first_non_null_value(
+        const std::string& name, const DataTypes& argument_types, const Array& parameters,
+        bool result_is_nullable) {
+    auto* fn = create_function_single_value<WindowFunctionData, WindowFunctionFirstNonNullData,
+                                            is_nullable, is_copy>(name, argument_types, parameters);
+    return AggregateFunctionPtr(fn);
+}
+
 template <bool is_nullable, bool is_copy>
 AggregateFunctionPtr create_aggregate_function_last(const std::string& name,
                                                     const DataTypes& argument_types,
@@ -465,4 +565,14 @@ AggregateFunctionPtr create_aggregate_function_last(const std::string& name,
                                          is_copy>(name, argument_types, parameters));
 }
 
+// Builds the WindowFunctionLastNonNullData aggregate; `result_is_nullable` is not read here.
+template <bool is_nullable, bool is_copy>
+AggregateFunctionPtr create_aggregate_function_last_non_null_value(
+        const std::string& name, const DataTypes& argument_types, const Array& parameters,
+        bool result_is_nullable) {
+    auto* fn = create_function_single_value<WindowFunctionData, WindowFunctionLastNonNullData,
+                                            is_nullable, is_copy>(name, argument_types, parameters);
+    return AggregateFunctionPtr(fn);
+}
+
 } // namespace doris::vectorized


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org