You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/06/30 10:03:27 UTC

[doris] 12/13: [hotfix](dev-1.0.1) BE prevent core by nullable not suit in hash join node

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 9177ccd30af9321a5ade5845fe7c87c71c9a7ab9
Author: lihaopeng <li...@baidu.com>
AuthorDate: Sun Jun 26 23:55:21 2022 +0800

    [hotfix](dev-1.0.1) BE prevent core by nullable not suit in hash join node
---
 be/src/vec/exec/join/vhash_join_node.cpp           | 145 ++++++++++++---------
 be/src/vec/exec/vaggregation_node.cpp              |  29 +++--
 be/src/vec/exec/vaggregation_node.h                |   1 +
 be/src/vec/exprs/vslot_ref.h                       |   1 -
 be/src/vec/functions/functions_logical.cpp         |  18 ++-
 .../org/apache/doris/planner/AggregationNode.java  |   1 +
 gensrc/thrift/PlanNodes.thrift                     |   1 +
 7 files changed, 118 insertions(+), 78 deletions(-)

diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index c8d7f6e77c..596bef712b 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -28,12 +28,11 @@
 
 namespace doris::vectorized {
 
-std::variant<std::false_type, std::true_type>
-static inline make_bool_variant(bool condition) {
+std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
     if (condition) {
-        return std::true_type{};
+        return std::true_type {};
     } else {
-        return std::false_type{};
+        return std::false_type {};
     }
 }
 
@@ -178,7 +177,7 @@ struct ProcessHashTableProbe {
     // output build side result column
     template <bool have_other_join_conjunct = false>
     void build_side_output_column(MutableColumns& mcol, int column_offset, int column_length,
-            const std::vector<bool>& output_slot_flags, int size) {
+                                  const std::vector<bool>& output_slot_flags, int size) {
         constexpr auto is_semi_anti_join = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
                                            JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
                                            JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN ||
@@ -192,8 +191,8 @@ struct ProcessHashTableProbe {
                 for (int i = 0; i < column_length; i++) {
                     auto& column = *_build_blocks[0].get_by_position(i).column;
                     if (output_slot_flags[i]) {
-                        mcol[i + column_offset]->insert_indices_from(column, _build_block_rows.data(),
-                                                                     _build_block_rows.data() + size);
+                        mcol[i + column_offset]->insert_indices_from(
+                                column, _build_block_rows.data(), _build_block_rows.data() + size);
                     } else {
                         mcol[i + column_offset]->resize(size);
                     }
@@ -205,14 +204,19 @@ struct ProcessHashTableProbe {
                             if constexpr (probe_all) {
                                 if (_build_block_offsets[j] == -1) {
                                     DCHECK(mcol[i + column_offset]->is_nullable());
-                                    assert_cast<ColumnNullable *>(
-                                            mcol[i + column_offset].get())->insert_join_null_data();
+                                    assert_cast<ColumnNullable*>(mcol[i + column_offset].get())
+                                            ->insert_join_null_data();
                                 } else {
-                                    auto &column = *_build_blocks[_build_block_offsets[j]].get_by_position(i).column;
-                                    mcol[i + column_offset]->insert_from(column, _build_block_rows[j]);
+                                    auto& column = *_build_blocks[_build_block_offsets[j]]
+                                                            .get_by_position(i)
+                                                            .column;
+                                    mcol[i + column_offset]->insert_from(column,
+                                                                         _build_block_rows[j]);
                                 }
                             } else {
-                                auto &column = *_build_blocks[_build_block_offsets[j]].get_by_position(i).column;
+                                auto& column = *_build_blocks[_build_block_offsets[j]]
+                                                        .get_by_position(i)
+                                                        .column;
                                 mcol[i + column_offset]->insert_from(column, _build_block_rows[j]);
                             }
                         }
@@ -225,7 +229,8 @@ struct ProcessHashTableProbe {
     }
 
     // output probe side result column
-    void probe_side_output_column(MutableColumns& mcol, const std::vector<bool>& output_slot_flags, int size) {
+    void probe_side_output_column(MutableColumns& mcol, const std::vector<bool>& output_slot_flags,
+                                  int size) {
         for (int i = 0; i < output_slot_flags.size(); ++i) {
             if (output_slot_flags[i]) {
                 auto& column = _probe_block.get_by_position(i).column;
@@ -244,8 +249,8 @@ struct ProcessHashTableProbe {
         using KeyGetter = typename HashTableContext::State;
         using Mapped = typename HashTableContext::Mapped;
 
-        int right_col_idx = _join_node->_is_right_semi_anti ? 0 :
-                _join_node->_left_table_data_types.size();
+        int right_col_idx =
+                _join_node->_is_right_semi_anti ? 0 : _join_node->_left_table_data_types.size();
         int right_col_len = _join_node->_right_table_data_types.size();
 
         KeyGetter key_getter(_probe_raw_ptrs, _join_node->_probe_key_sz, nullptr);
@@ -258,15 +263,15 @@ struct ProcessHashTableProbe {
         memset(_items_counts.data(), 0, sizeof(uint32_t) * _probe_rows);
 
         constexpr auto need_to_set_visited = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
-                                       JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
-                                       JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN ||
-                                       JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
+                                             JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
+                                             JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN ||
+                                             JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
 
         constexpr auto is_right_semi_anti_join = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
-                                            JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN;
+                                                 JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN;
 
         constexpr auto probe_all = JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN ||
-                                     JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
+                                   JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
 
         {
             SCOPED_TIMER(_search_hashtable_timer);
@@ -279,9 +284,11 @@ struct ProcessHashTableProbe {
                 }
                 int last_offset = current_offset;
                 auto find_result = (*null_map)[_probe_index]
-                                    ? decltype(key_getter.find_key(hash_table_ctx.hash_table, _probe_index,
-                                                                    _arena)) {nullptr, false}
-                                    : key_getter.find_key(hash_table_ctx.hash_table, _probe_index, _arena);
+                                           ? decltype(key_getter.find_key(hash_table_ctx.hash_table,
+                                                                          _probe_index,
+                                                                          _arena)) {nullptr, false}
+                                           : key_getter.find_key(hash_table_ctx.hash_table,
+                                                                 _probe_index, _arena);
 
                 if constexpr (JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) {
                     if (!find_result.is_found()) {
@@ -297,8 +304,7 @@ struct ProcessHashTableProbe {
                         // TODO: Iterators are currently considered to be a heavy operation and have a certain impact on performance.
                         // We should rethink whether to use this iterator mode in the future. Now just opt the one row case
                         if (mapped.get_row_count() == 1) {
-                            if constexpr (need_to_set_visited)
-                                mapped.visited = true;
+                            if constexpr (need_to_set_visited) mapped.visited = true;
 
                             if constexpr (!is_right_semi_anti_join) {
                                 _build_block_offsets[current_offset] = mapped.block_offset;
@@ -308,7 +314,8 @@ struct ProcessHashTableProbe {
                         } else {
                             // prefetch is more useful while matching to multiple rows
                             if (_probe_index + 2 < _probe_rows)
-                                key_getter.prefetch(hash_table_ctx.hash_table, _probe_index + 2, _arena);
+                                key_getter.prefetch(hash_table_ctx.hash_table, _probe_index + 2,
+                                                    _arena);
 
                             for (auto it = mapped.begin(); it.ok(); ++it) {
                                 if constexpr (!is_right_semi_anti_join) {
@@ -321,8 +328,7 @@ struct ProcessHashTableProbe {
                                     }
                                     ++current_offset;
                                 }
-                                if constexpr (need_to_set_visited)
-                                    it->visited = true;
+                                if constexpr (need_to_set_visited) it->visited = true;
                             }
                         }
                     } else {
@@ -345,10 +351,11 @@ struct ProcessHashTableProbe {
         {
             SCOPED_TIMER(_build_side_output_timer);
             build_side_output_column(mcol, right_col_idx, right_col_len,
-                    _join_node->_right_output_slot_flags, current_offset);
+                                     _join_node->_right_output_slot_flags, current_offset);
         }
 
-        if constexpr (JoinOpType::value != TJoinOp::RIGHT_SEMI_JOIN && JoinOpType::value != TJoinOp::RIGHT_ANTI_JOIN) {
+        if constexpr (JoinOpType::value != TJoinOp::RIGHT_SEMI_JOIN &&
+                      JoinOpType::value != TJoinOp::RIGHT_ANTI_JOIN) {
             SCOPED_TIMER(_probe_side_output_timer);
             probe_side_output_column(mcol, _join_node->_left_output_slot_flags, current_offset);
         }
@@ -453,7 +460,7 @@ struct ProcessHashTableProbe {
         {
             SCOPED_TIMER(_build_side_output_timer);
             build_side_output_column<true>(mcol, right_col_idx, right_col_len,
-                    _join_node->_right_output_slot_flags, current_offset);
+                                           _join_node->_right_output_slot_flags, current_offset);
         }
         {
             SCOPED_TIMER(_probe_side_output_timer);
@@ -528,9 +535,11 @@ struct ProcessHashTableProbe {
                 auto new_filter_column = ColumnVector<UInt8>::create();
                 auto& filter_map = new_filter_column->get_data();
 
-                if (!column->empty()) filter_map.emplace_back(column->get_bool(0) && visited_map[0]);
+                if (!column->empty())
+                    filter_map.emplace_back(column->get_bool(0) && visited_map[0]);
                 for (int i = 1; i < column->size(); ++i) {
-                    if ((visited_map[i] && column->get_bool(i)) || (same_to_prev[i] && filter_map[i - 1])) {
+                    if ((visited_map[i] && column->get_bool(i)) ||
+                        (same_to_prev[i] && filter_map[i - 1])) {
                         filter_map.push_back(true);
                         filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1];
                     } else {
@@ -561,7 +570,8 @@ struct ProcessHashTableProbe {
                 output_block->clear();
             } else {
                 if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN ||
-                          JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) orig_columns = right_col_idx;
+                              JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN)
+                    orig_columns = right_col_idx;
                 Block::filter_block(output_block, result_column_id, orig_columns);
             }
         }
@@ -597,11 +607,9 @@ struct ProcessHashTableProbe {
             auto& mapped = iter->get_second();
             for (auto it = mapped.begin(); it.ok(); ++it) {
                 if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN) {
-                    if (it->visited)
-                        insert_from_hash_table(it->block_offset, it->row_num);
+                    if (it->visited) insert_from_hash_table(it->block_offset, it->row_num);
                 } else {
-                    if (!it->visited)
-                        insert_from_hash_table(it->block_offset, it->row_num);
+                    if (!it->visited) insert_from_hash_table(it->block_offset, it->row_num);
                 }
             }
         }
@@ -612,7 +620,7 @@ struct ProcessHashTableProbe {
                       JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) {
             for (int i = 0; i < right_col_idx; ++i) {
                 for (int j = 0; j < block_size; ++j) {
-                    assert_cast<ColumnNullable *>(mcol[i].get())->insert_join_null_data();
+                    assert_cast<ColumnNullable*>(mcol[i].get())->insert_join_null_data();
                 }
             }
         }
@@ -674,15 +682,15 @@ HashJoinNode::~HashJoinNode() = default;
 
 void HashJoinNode::init_join_op() {
     switch (_join_op) {
-#define M(NAME)                                                                                      \
-        case TJoinOp::NAME:                                                                          \
-            _join_op_variants.emplace<std::integral_constant<TJoinOp::type, TJoinOp::NAME>>();       \
-            break;
+#define M(NAME)                                                                            \
+    case TJoinOp::NAME:                                                                    \
+        _join_op_variants.emplace<std::integral_constant<TJoinOp::type, TJoinOp::NAME>>(); \
+        break;
         APPLY_FOR_JOINOP_VARIANTS(M);
 #undef M
-        default:
-            //do nothing
-            break;
+    default:
+        //do nothing
+        break;
     }
 }
 
@@ -741,8 +749,8 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
     }
 
     for (const auto& filter_desc : _runtime_filter_descs) {
-        RETURN_IF_ERROR(state->runtime_filter_mgr()->regist_filter(RuntimeFilterRole::PRODUCER,
-                                                                   filter_desc, state->query_options()));
+        RETURN_IF_ERROR(state->runtime_filter_mgr()->regist_filter(
+                RuntimeFilterRole::PRODUCER, filter_desc, state->query_options()));
     }
 
     // init left/right output slots flags, only column of slot_id in _hash_output_slot_ids need
@@ -753,9 +761,10 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
     auto init_output_slots_flags = [this](auto& tuple_descs, auto& output_slot_flags) {
         for (const auto& tuple_desc : tuple_descs) {
             for (const auto& slot_desc : tuple_desc->slots()) {
-                output_slot_flags.emplace_back(_hash_output_slot_ids.empty() ||
-                                               std::find(_hash_output_slot_ids.begin(), _hash_output_slot_ids.end(),
-                                                         slot_desc->id()) != _hash_output_slot_ids.end());
+                output_slot_flags.emplace_back(
+                        _hash_output_slot_ids.empty() ||
+                        std::find(_hash_output_slot_ids.begin(), _hash_output_slot_ids.end(),
+                                  slot_desc->id()) != _hash_output_slot_ids.end());
             }
         }
     };
@@ -916,8 +925,8 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo
                             LOG(FATAL) << "FATAL: uninited hash table";
                         }
                     }
-            }, _hash_table_variants,
-               _join_op_variants,
+                },
+                _hash_table_variants, _join_op_variants,
                 make_bool_variant(_have_other_join_conjunct),
                 make_bool_variant(_probe_ignore_null));
     } else if (_probe_eos) {
@@ -935,8 +944,7 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo
                             LOG(FATAL) << "FATAL: uninited hash table";
                         }
                     },
-                    _hash_table_variants,
-                    _join_op_variants);
+                    _hash_table_variants, _join_op_variants);
         } else {
             *eos = true;
             return Status::OK();
@@ -1022,6 +1030,7 @@ Status HashJoinNode::open(RuntimeState* state) {
 
     RETURN_IF_ERROR(VExpr::open(_build_expr_ctxs, state));
     RETURN_IF_ERROR(VExpr::open(_probe_expr_ctxs, state));
+    RETURN_IF_ERROR(VExpr::open(_output_expr_ctxs, state));
     if (_vother_join_conjunct_ptr) {
         RETURN_IF_ERROR((*_vother_join_conjunct_ptr)->open(state));
     }
@@ -1051,7 +1060,9 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) {
         _mem_used += block.allocated_bytes();
         RETURN_IF_LIMIT_EXCEEDED(state, "Hash join, while getting next from the child 1.");
 
-        if (block.rows() != 0) { mutable_block.merge(block); }
+        if (block.rows() != 0) {
+            mutable_block.merge(block);
+        }
 
         // make one block for each 4 gigabytes
         constexpr static auto BUILD_BLOCK_MAX_SIZE = 4 * 1024UL * 1024UL * 1024UL;
@@ -1100,7 +1111,7 @@ Status HashJoinNode::_extract_build_join_column(Block& block, NullMap& null_map,
         // TODO: opt the column is const
         block.get_by_position(result_col_id).column =
                 block.get_by_position(result_col_id).column->convert_to_full_column_if_const();
-        
+
         if (_is_null_safe_eq_join[i]) {
             raw_ptrs[i] = block.get_by_position(result_col_id).column.get();
         } else {
@@ -1209,7 +1220,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uin
                 if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) {
 #define CALL_BUILD_FUNCTION(HAS_NULL, BUILD_UNIQUE)                                           \
     ProcessHashTableBuild<HashTableCtxType, HAS_NULL, BUILD_UNIQUE> hash_table_build_process( \
-            rows, block, raw_ptrs, this, state->batch_size(), offset);                       \
+            rows, block, raw_ptrs, this, state->batch_size(), offset);                        \
     st = hash_table_build_process(arg, &null_map_val, has_runtime_filter);
                     if (std::pair {has_null, _build_unique} == std::pair {true, true}) {
                         CALL_BUILD_FUNCTION(true, true);
@@ -1332,13 +1343,25 @@ Status HashJoinNode::_build_output_block(Block* origin_block, Block* output_bloc
                          : MutableBlock(VectorizedUtils::create_empty_columnswithtypename(
                                    _output_row_desc));
     auto rows = origin_block->rows();
+    // TODO: Once the FE plan guarantees that the output exprs, the origin block and the mutable
+    // columns agree on nullability, we should replace `insert_column_datas` with `insert_range_from`.
+
+    auto insert_column_datas = [](auto& to, const auto& from, size_t rows) {
+        if (to->is_nullable() && !from.is_nullable()) {
+            auto& null_column = reinterpret_cast<ColumnNullable&>(*to);
+            null_column.get_nested_column().insert_range_from(from, 0, rows);
+            null_column.get_null_map_column().get_data().resize_fill(rows, 0);
+        } else {
+            to->insert_range_from(from, 0, rows);
+        }
+    };
     if (rows != 0) {
         auto& mutable_columns = mutable_block.mutable_columns();
         if (_output_expr_ctxs.empty()) {
             DCHECK(mutable_columns.size() == origin_block->columns());
             for (int i = 0; i < mutable_columns.size(); ++i) {
-                mutable_columns[i]->insert_range_from(*origin_block->get_by_position(i).column, 0,
-                                                      rows);
+                insert_column_datas(mutable_columns[i], *origin_block->get_by_position(i).column,
+                                    rows);
             }
         } else {
             DCHECK(mutable_columns.size() == _output_expr_ctxs.size());
@@ -1347,7 +1370,7 @@ Status HashJoinNode::_build_output_block(Block* origin_block, Block* output_bloc
                 RETURN_IF_ERROR(_output_expr_ctxs[i]->execute(origin_block, &result_column_id));
                 auto column_ptr = origin_block->get_by_position(result_column_id)
                                           .column->convert_to_full_column_if_const();
-                mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+                insert_column_datas(mutable_columns[i], *column_ptr, rows);
             }
         }
 
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index 8230d4d697..66f6037715 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -119,6 +119,11 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
     const auto& agg_functions = tnode.agg_node.aggregate_functions;
     _is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
                             [](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
+
+    // Only needed for a corner case in queries, see https://github.com/apache/doris/issues/10302 :
+    // the agg node may do a merge during the update stage. In this case the merge function should
+    // use the probe expr (slot ref) column id to locate its input, like the update function does.
+    _is_update_stage = tnode.agg_node.is_update_stage;
     return Status::OK();
 }
 
@@ -526,12 +531,14 @@ Status AggregationNode::_merge_without_key(Block* block) {
     std::unique_ptr<char[]> deserialize_buffer(new char[_total_size_of_aggregate_states]);
     int rows = block->rows();
     for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-        DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
-               _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
-        int col_id =
-                ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())->column_id();
         if (_aggregate_evaluators[i]->is_merge()) {
-            auto column = block->get_by_position(col_id).column;
+            auto column =
+                    block->get_by_position(_is_update_stage ? ((VSlotRef*)_aggregate_evaluators[i]
+                                                                       ->input_exprs_ctxs()[0]
+                                                                       ->root())
+                                                                      ->column_id()
+                                                            : i)
+                            .column;
             if (column->is_nullable()) {
                 column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
             }
@@ -1050,12 +1057,14 @@ Status AggregationNode::_merge_with_serialized_key(Block* block) {
     std::unique_ptr<char[]> deserialize_buffer(new char[_total_size_of_aggregate_states]);
 
     for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
-        DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
-               _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
-        int col_id =
-                ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())->column_id();
         if (_aggregate_evaluators[i]->is_merge()) {
-            auto column = block->get_by_position(col_id).column;
+            auto column =
+                    block->get_by_position(_is_update_stage ? ((VSlotRef*)_aggregate_evaluators[i]
+                                                                       ->input_exprs_ctxs()[0]
+                                                                       ->root())
+                                                                      ->column_id()
+                                                            : i + key_size)
+                            .column;
             if (column->is_nullable()) {
                 column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
             }
diff --git a/be/src/vec/exec/vaggregation_node.h b/be/src/vec/exec/vaggregation_node.h
index f020b90a6e..d46f856f8b 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -433,6 +433,7 @@ private:
 
     bool _needs_finalize;
     bool _is_merge;
+    bool _is_update_stage;
     std::unique_ptr<MemPool> _mem_pool;
 
     size_t _align_aggregate_states = 1;
diff --git a/be/src/vec/exprs/vslot_ref.h b/be/src/vec/exprs/vslot_ref.h
index 1bc78a4c5c..c00a018b3b 100644
--- a/be/src/vec/exprs/vslot_ref.h
+++ b/be/src/vec/exprs/vslot_ref.h
@@ -38,7 +38,6 @@ public:
     virtual const std::string& expr_name() const override;
     virtual std::string debug_string() const override;
     virtual bool is_constant() const override { return false; }
-
     const int column_id() const { return _column_id; }
 
 private:
diff --git a/be/src/vec/functions/functions_logical.cpp b/be/src/vec/functions/functions_logical.cpp
index 8ef554fd9b..0c8bded880 100644
--- a/be/src/vec/functions/functions_logical.cpp
+++ b/be/src/vec/functions/functions_logical.cpp
@@ -204,7 +204,7 @@ class AssociativeGenericApplierImpl {
 public:
     /// Remembers the last N columns from `in`.
     AssociativeGenericApplierImpl(const ColumnRawPtrs& in)
-            : val_getter{ValueGetterBuilder::build(in[in.size() - N])}, next{in} {}
+            : val_getter {ValueGetterBuilder::build(in[in.size() - N])}, next {in} {}
 
     /// Returns a combination of values in the i-th row of all columns stored in the constructor.
     inline ResultValueType apply(const size_t i) const {
@@ -227,7 +227,7 @@ class AssociativeGenericApplierImpl<Op, 1> {
 public:
     /// Remembers the last N columns from `in`.
     AssociativeGenericApplierImpl(const ColumnRawPtrs& in)
-            : val_getter{ValueGetterBuilder::build(in[in.size() - 1])} {}
+            : val_getter {ValueGetterBuilder::build(in[in.size() - 1])} {}
 
     inline ResultValueType apply(const size_t i) const { return val_getter(i); }
 
@@ -449,14 +449,20 @@ Status FunctionAnyArityLogical<Impl, Name>::execute_impl(FunctionContext* contex
                                                          size_t result_index,
                                                          size_t input_rows_count) {
     ColumnRawPtrs args_in;
-    for (const auto arg_index : arguments)
-        args_in.push_back(block.get_by_position(arg_index).column.get());
+    bool is_nullable = false;
+    for (const auto arg_index : arguments) {
+        auto& data = block.get_by_position(arg_index);
+        args_in.push_back(data.column.get());
+        is_nullable |= data.column->is_nullable();
+    }
 
     auto& result_info = block.get_by_position(result_index);
-    if (result_info.type->is_nullable())
+    if (is_nullable) {
+        result_info.type = make_nullable(result_info.type);
         execute_for_ternary_logic_impl<Impl>(std::move(args_in), result_info, input_rows_count);
-    else
+    } else {
         basic_execute_impl<Impl>(std::move(args_in), result_info, input_rows_count);
+    }
     return Status::OK();
 }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
index 65ee4d6d8a..e7dee2651c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
@@ -277,6 +277,7 @@ public class AggregationNode extends PlanNode {
                   aggInfo.getIntermediateTupleId().asInt(),
                   aggInfo.getOutputTupleId().asInt(), needsFinalize);
         msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
+        msg.agg_node.setIsUpdateStage(!aggInfo.isMerge());
         List<Expr> groupingExprs = aggInfo.getGroupingExprs();
         if (groupingExprs != null) {
             msg.agg_node.setGroupingExprs(Expr.treesToThrift(groupingExprs));
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 34be03f8b6..dfe90968cb 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -464,6 +464,7 @@ struct TAggregationNode {
   // rows have been aggregated, and this node is not an intermediate node.
   5: required bool need_finalize
   6: optional bool use_streaming_preaggregation
+  7: optional bool is_update_stage
 }
 
 struct TRepeatNode {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org