You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this copy.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/06/30 10:03:27 UTC
[doris] 12/13: [hotfix](dev-1.0.1) BE prevent core by nullable not suit in hash join node
This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/doris.git
commit 9177ccd30af9321a5ade5845fe7c87c71c9a7ab9
Author: lihaopeng <li...@baidu.com>
AuthorDate: Sun Jun 26 23:55:21 2022 +0800
[hotfix](dev-1.0.1) BE prevent core by nullable not suit in hash join node
---
be/src/vec/exec/join/vhash_join_node.cpp | 145 ++++++++++++---------
be/src/vec/exec/vaggregation_node.cpp | 29 +++--
be/src/vec/exec/vaggregation_node.h | 1 +
be/src/vec/exprs/vslot_ref.h | 1 -
be/src/vec/functions/functions_logical.cpp | 18 ++-
.../org/apache/doris/planner/AggregationNode.java | 1 +
gensrc/thrift/PlanNodes.thrift | 1 +
7 files changed, 118 insertions(+), 78 deletions(-)
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp
index c8d7f6e77c..596bef712b 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -28,12 +28,11 @@
namespace doris::vectorized {
-std::variant<std::false_type, std::true_type>
-static inline make_bool_variant(bool condition) {
+std::variant<std::false_type, std::true_type> static inline make_bool_variant(bool condition) {
if (condition) {
- return std::true_type{};
+ return std::true_type {};
} else {
- return std::false_type{};
+ return std::false_type {};
}
}
@@ -178,7 +177,7 @@ struct ProcessHashTableProbe {
// output build side result column
template <bool have_other_join_conjunct = false>
void build_side_output_column(MutableColumns& mcol, int column_offset, int column_length,
- const std::vector<bool>& output_slot_flags, int size) {
+ const std::vector<bool>& output_slot_flags, int size) {
constexpr auto is_semi_anti_join = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN ||
@@ -192,8 +191,8 @@ struct ProcessHashTableProbe {
for (int i = 0; i < column_length; i++) {
auto& column = *_build_blocks[0].get_by_position(i).column;
if (output_slot_flags[i]) {
- mcol[i + column_offset]->insert_indices_from(column, _build_block_rows.data(),
- _build_block_rows.data() + size);
+ mcol[i + column_offset]->insert_indices_from(
+ column, _build_block_rows.data(), _build_block_rows.data() + size);
} else {
mcol[i + column_offset]->resize(size);
}
@@ -205,14 +204,19 @@ struct ProcessHashTableProbe {
if constexpr (probe_all) {
if (_build_block_offsets[j] == -1) {
DCHECK(mcol[i + column_offset]->is_nullable());
- assert_cast<ColumnNullable *>(
- mcol[i + column_offset].get())->insert_join_null_data();
+ assert_cast<ColumnNullable*>(mcol[i + column_offset].get())
+ ->insert_join_null_data();
} else {
- auto &column = *_build_blocks[_build_block_offsets[j]].get_by_position(i).column;
- mcol[i + column_offset]->insert_from(column, _build_block_rows[j]);
+ auto& column = *_build_blocks[_build_block_offsets[j]]
+ .get_by_position(i)
+ .column;
+ mcol[i + column_offset]->insert_from(column,
+ _build_block_rows[j]);
}
} else {
- auto &column = *_build_blocks[_build_block_offsets[j]].get_by_position(i).column;
+ auto& column = *_build_blocks[_build_block_offsets[j]]
+ .get_by_position(i)
+ .column;
mcol[i + column_offset]->insert_from(column, _build_block_rows[j]);
}
}
@@ -225,7 +229,8 @@ struct ProcessHashTableProbe {
}
// output probe side result column
- void probe_side_output_column(MutableColumns& mcol, const std::vector<bool>& output_slot_flags, int size) {
+ void probe_side_output_column(MutableColumns& mcol, const std::vector<bool>& output_slot_flags,
+ int size) {
for (int i = 0; i < output_slot_flags.size(); ++i) {
if (output_slot_flags[i]) {
auto& column = _probe_block.get_by_position(i).column;
@@ -244,8 +249,8 @@ struct ProcessHashTableProbe {
using KeyGetter = typename HashTableContext::State;
using Mapped = typename HashTableContext::Mapped;
- int right_col_idx = _join_node->_is_right_semi_anti ? 0 :
- _join_node->_left_table_data_types.size();
+ int right_col_idx =
+ _join_node->_is_right_semi_anti ? 0 : _join_node->_left_table_data_types.size();
int right_col_len = _join_node->_right_table_data_types.size();
KeyGetter key_getter(_probe_raw_ptrs, _join_node->_probe_key_sz, nullptr);
@@ -258,15 +263,15 @@ struct ProcessHashTableProbe {
memset(_items_counts.data(), 0, sizeof(uint32_t) * _probe_rows);
constexpr auto need_to_set_visited = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
- JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
- JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN ||
- JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
+ JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN ||
+ JoinOpType::value == TJoinOp::RIGHT_OUTER_JOIN ||
+ JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
constexpr auto is_right_semi_anti_join = JoinOpType::value == TJoinOp::RIGHT_ANTI_JOIN ||
- JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN;
+ JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN;
constexpr auto probe_all = JoinOpType::value == TJoinOp::LEFT_OUTER_JOIN ||
- JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
+ JoinOpType::value == TJoinOp::FULL_OUTER_JOIN;
{
SCOPED_TIMER(_search_hashtable_timer);
@@ -279,9 +284,11 @@ struct ProcessHashTableProbe {
}
int last_offset = current_offset;
auto find_result = (*null_map)[_probe_index]
- ? decltype(key_getter.find_key(hash_table_ctx.hash_table, _probe_index,
- _arena)) {nullptr, false}
- : key_getter.find_key(hash_table_ctx.hash_table, _probe_index, _arena);
+ ? decltype(key_getter.find_key(hash_table_ctx.hash_table,
+ _probe_index,
+ _arena)) {nullptr, false}
+ : key_getter.find_key(hash_table_ctx.hash_table,
+ _probe_index, _arena);
if constexpr (JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) {
if (!find_result.is_found()) {
@@ -297,8 +304,7 @@ struct ProcessHashTableProbe {
// TODO: Iterators are currently considered to be a heavy operation and have a certain impact on performance.
// We should rethink whether to use this iterator mode in the future. Now just opt the one row case
if (mapped.get_row_count() == 1) {
- if constexpr (need_to_set_visited)
- mapped.visited = true;
+ if constexpr (need_to_set_visited) mapped.visited = true;
if constexpr (!is_right_semi_anti_join) {
_build_block_offsets[current_offset] = mapped.block_offset;
@@ -308,7 +314,8 @@ struct ProcessHashTableProbe {
} else {
// prefetch is more useful while matching to multiple rows
if (_probe_index + 2 < _probe_rows)
- key_getter.prefetch(hash_table_ctx.hash_table, _probe_index + 2, _arena);
+ key_getter.prefetch(hash_table_ctx.hash_table, _probe_index + 2,
+ _arena);
for (auto it = mapped.begin(); it.ok(); ++it) {
if constexpr (!is_right_semi_anti_join) {
@@ -321,8 +328,7 @@ struct ProcessHashTableProbe {
}
++current_offset;
}
- if constexpr (need_to_set_visited)
- it->visited = true;
+ if constexpr (need_to_set_visited) it->visited = true;
}
}
} else {
@@ -345,10 +351,11 @@ struct ProcessHashTableProbe {
{
SCOPED_TIMER(_build_side_output_timer);
build_side_output_column(mcol, right_col_idx, right_col_len,
- _join_node->_right_output_slot_flags, current_offset);
+ _join_node->_right_output_slot_flags, current_offset);
}
- if constexpr (JoinOpType::value != TJoinOp::RIGHT_SEMI_JOIN && JoinOpType::value != TJoinOp::RIGHT_ANTI_JOIN) {
+ if constexpr (JoinOpType::value != TJoinOp::RIGHT_SEMI_JOIN &&
+ JoinOpType::value != TJoinOp::RIGHT_ANTI_JOIN) {
SCOPED_TIMER(_probe_side_output_timer);
probe_side_output_column(mcol, _join_node->_left_output_slot_flags, current_offset);
}
@@ -453,7 +460,7 @@ struct ProcessHashTableProbe {
{
SCOPED_TIMER(_build_side_output_timer);
build_side_output_column<true>(mcol, right_col_idx, right_col_len,
- _join_node->_right_output_slot_flags, current_offset);
+ _join_node->_right_output_slot_flags, current_offset);
}
{
SCOPED_TIMER(_probe_side_output_timer);
@@ -528,9 +535,11 @@ struct ProcessHashTableProbe {
auto new_filter_column = ColumnVector<UInt8>::create();
auto& filter_map = new_filter_column->get_data();
- if (!column->empty()) filter_map.emplace_back(column->get_bool(0) && visited_map[0]);
+ if (!column->empty())
+ filter_map.emplace_back(column->get_bool(0) && visited_map[0]);
for (int i = 1; i < column->size(); ++i) {
- if ((visited_map[i] && column->get_bool(i)) || (same_to_prev[i] && filter_map[i - 1])) {
+ if ((visited_map[i] && column->get_bool(i)) ||
+ (same_to_prev[i] && filter_map[i - 1])) {
filter_map.push_back(true);
filter_map[i - 1] = !same_to_prev[i] && filter_map[i - 1];
} else {
@@ -561,7 +570,8 @@ struct ProcessHashTableProbe {
output_block->clear();
} else {
if constexpr (JoinOpType::value == TJoinOp::LEFT_SEMI_JOIN ||
- JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN) orig_columns = right_col_idx;
+ JoinOpType::value == TJoinOp::LEFT_ANTI_JOIN)
+ orig_columns = right_col_idx;
Block::filter_block(output_block, result_column_id, orig_columns);
}
}
@@ -597,11 +607,9 @@ struct ProcessHashTableProbe {
auto& mapped = iter->get_second();
for (auto it = mapped.begin(); it.ok(); ++it) {
if constexpr (JoinOpType::value == TJoinOp::RIGHT_SEMI_JOIN) {
- if (it->visited)
- insert_from_hash_table(it->block_offset, it->row_num);
+ if (it->visited) insert_from_hash_table(it->block_offset, it->row_num);
} else {
- if (!it->visited)
- insert_from_hash_table(it->block_offset, it->row_num);
+ if (!it->visited) insert_from_hash_table(it->block_offset, it->row_num);
}
}
}
@@ -612,7 +620,7 @@ struct ProcessHashTableProbe {
JoinOpType::value == TJoinOp::FULL_OUTER_JOIN) {
for (int i = 0; i < right_col_idx; ++i) {
for (int j = 0; j < block_size; ++j) {
- assert_cast<ColumnNullable *>(mcol[i].get())->insert_join_null_data();
+ assert_cast<ColumnNullable*>(mcol[i].get())->insert_join_null_data();
}
}
}
@@ -674,15 +682,15 @@ HashJoinNode::~HashJoinNode() = default;
void HashJoinNode::init_join_op() {
switch (_join_op) {
-#define M(NAME) \
- case TJoinOp::NAME: \
- _join_op_variants.emplace<std::integral_constant<TJoinOp::type, TJoinOp::NAME>>(); \
- break;
+#define M(NAME) \
+ case TJoinOp::NAME: \
+ _join_op_variants.emplace<std::integral_constant<TJoinOp::type, TJoinOp::NAME>>(); \
+ break;
APPLY_FOR_JOINOP_VARIANTS(M);
#undef M
- default:
- //do nothing
- break;
+ default:
+ //do nothing
+ break;
}
}
@@ -741,8 +749,8 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
}
for (const auto& filter_desc : _runtime_filter_descs) {
- RETURN_IF_ERROR(state->runtime_filter_mgr()->regist_filter(RuntimeFilterRole::PRODUCER,
- filter_desc, state->query_options()));
+ RETURN_IF_ERROR(state->runtime_filter_mgr()->regist_filter(
+ RuntimeFilterRole::PRODUCER, filter_desc, state->query_options()));
}
// init left/right output slots flags, only column of slot_id in _hash_output_slot_ids need
@@ -753,9 +761,10 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) {
auto init_output_slots_flags = [this](auto& tuple_descs, auto& output_slot_flags) {
for (const auto& tuple_desc : tuple_descs) {
for (const auto& slot_desc : tuple_desc->slots()) {
- output_slot_flags.emplace_back(_hash_output_slot_ids.empty() ||
- std::find(_hash_output_slot_ids.begin(), _hash_output_slot_ids.end(),
- slot_desc->id()) != _hash_output_slot_ids.end());
+ output_slot_flags.emplace_back(
+ _hash_output_slot_ids.empty() ||
+ std::find(_hash_output_slot_ids.begin(), _hash_output_slot_ids.end(),
+ slot_desc->id()) != _hash_output_slot_ids.end());
}
}
};
@@ -916,8 +925,8 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo
LOG(FATAL) << "FATAL: uninited hash table";
}
}
- }, _hash_table_variants,
- _join_op_variants,
+ },
+ _hash_table_variants, _join_op_variants,
make_bool_variant(_have_other_join_conjunct),
make_bool_variant(_probe_ignore_null));
} else if (_probe_eos) {
@@ -935,8 +944,7 @@ Status HashJoinNode::get_next(RuntimeState* state, Block* output_block, bool* eo
LOG(FATAL) << "FATAL: uninited hash table";
}
},
- _hash_table_variants,
- _join_op_variants);
+ _hash_table_variants, _join_op_variants);
} else {
*eos = true;
return Status::OK();
@@ -1022,6 +1030,7 @@ Status HashJoinNode::open(RuntimeState* state) {
RETURN_IF_ERROR(VExpr::open(_build_expr_ctxs, state));
RETURN_IF_ERROR(VExpr::open(_probe_expr_ctxs, state));
+ RETURN_IF_ERROR(VExpr::open(_output_expr_ctxs, state));
if (_vother_join_conjunct_ptr) {
RETURN_IF_ERROR((*_vother_join_conjunct_ptr)->open(state));
}
@@ -1051,7 +1060,9 @@ Status HashJoinNode::_hash_table_build(RuntimeState* state) {
_mem_used += block.allocated_bytes();
RETURN_IF_LIMIT_EXCEEDED(state, "Hash join, while getting next from the child 1.");
- if (block.rows() != 0) { mutable_block.merge(block); }
+ if (block.rows() != 0) {
+ mutable_block.merge(block);
+ }
// make one block for each 4 gigabytes
constexpr static auto BUILD_BLOCK_MAX_SIZE = 4 * 1024UL * 1024UL * 1024UL;
@@ -1100,7 +1111,7 @@ Status HashJoinNode::_extract_build_join_column(Block& block, NullMap& null_map,
// TODO: opt the column is const
block.get_by_position(result_col_id).column =
block.get_by_position(result_col_id).column->convert_to_full_column_if_const();
-
+
if (_is_null_safe_eq_join[i]) {
raw_ptrs[i] = block.get_by_position(result_col_id).column.get();
} else {
@@ -1209,7 +1220,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uin
if constexpr (!std::is_same_v<HashTableCtxType, std::monostate>) {
#define CALL_BUILD_FUNCTION(HAS_NULL, BUILD_UNIQUE) \
ProcessHashTableBuild<HashTableCtxType, HAS_NULL, BUILD_UNIQUE> hash_table_build_process( \
- rows, block, raw_ptrs, this, state->batch_size(), offset); \
+ rows, block, raw_ptrs, this, state->batch_size(), offset); \
st = hash_table_build_process(arg, &null_map_val, has_runtime_filter);
if (std::pair {has_null, _build_unique} == std::pair {true, true}) {
CALL_BUILD_FUNCTION(true, true);
@@ -1332,13 +1343,25 @@ Status HashJoinNode::_build_output_block(Block* origin_block, Block* output_bloc
: MutableBlock(VectorizedUtils::create_empty_columnswithtypename(
_output_row_desc));
auto rows = origin_block->rows();
+ // TODO: Once the FE planner guarantees that the output exprs, the origin block and the mutable
+ // columns agree on nullability, replace `insert_column_datas` with `insert_range_from`.
+
+ auto insert_column_datas = [](auto& to, const auto& from, size_t rows) {
+ if (to->is_nullable() && !from.is_nullable()) {
+ auto& null_column = reinterpret_cast<ColumnNullable&>(*to);
+ null_column.get_nested_column().insert_range_from(from, 0, rows);
+ null_column.get_null_map_column().get_data().resize_fill(rows, 0);
+ } else {
+ to->insert_range_from(from, 0, rows);
+ }
+ };
if (rows != 0) {
auto& mutable_columns = mutable_block.mutable_columns();
if (_output_expr_ctxs.empty()) {
DCHECK(mutable_columns.size() == origin_block->columns());
for (int i = 0; i < mutable_columns.size(); ++i) {
- mutable_columns[i]->insert_range_from(*origin_block->get_by_position(i).column, 0,
- rows);
+ insert_column_datas(mutable_columns[i], *origin_block->get_by_position(i).column,
+ rows);
}
} else {
DCHECK(mutable_columns.size() == _output_expr_ctxs.size());
@@ -1347,7 +1370,7 @@ Status HashJoinNode::_build_output_block(Block* origin_block, Block* output_bloc
RETURN_IF_ERROR(_output_expr_ctxs[i]->execute(origin_block, &result_column_id));
auto column_ptr = origin_block->get_by_position(result_column_id)
.column->convert_to_full_column_if_const();
- mutable_columns[i]->insert_range_from(*column_ptr, 0, rows);
+ insert_column_datas(mutable_columns[i], *column_ptr, rows);
}
}
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index 8230d4d697..66f6037715 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -119,6 +119,11 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
const auto& agg_functions = tnode.agg_node.aggregate_functions;
_is_merge = std::any_of(agg_functions.cbegin(), agg_functions.cend(),
[](const auto& e) { return e.nodes[0].agg_expr.is_merge_agg; });
+
+ // Corner case only (see https://github.com/apache/doris/issues/10302): the agg node may run a
+ // merge in the update stage. In that case the merge function must locate its input column via
+ // the probe expr's (slot ref's) column id, just as the update function does.
+ _is_update_stage = tnode.agg_node.is_update_stage;
return Status::OK();
}
@@ -526,12 +531,14 @@ Status AggregationNode::_merge_without_key(Block* block) {
std::unique_ptr<char[]> deserialize_buffer(new char[_total_size_of_aggregate_states]);
int rows = block->rows();
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
- DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
- _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
- int col_id =
- ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())->column_id();
if (_aggregate_evaluators[i]->is_merge()) {
- auto column = block->get_by_position(col_id).column;
+ auto column =
+ block->get_by_position(_is_update_stage ? ((VSlotRef*)_aggregate_evaluators[i]
+ ->input_exprs_ctxs()[0]
+ ->root())
+ ->column_id()
+ : i)
+ .column;
if (column->is_nullable()) {
column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
}
@@ -1050,12 +1057,14 @@ Status AggregationNode::_merge_with_serialized_key(Block* block) {
std::unique_ptr<char[]> deserialize_buffer(new char[_total_size_of_aggregate_states]);
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
- DCHECK(_aggregate_evaluators[i]->input_exprs_ctxs().size() == 1 &&
- _aggregate_evaluators[i]->input_exprs_ctxs()[0]->root()->is_slot_ref());
- int col_id =
- ((VSlotRef*)_aggregate_evaluators[i]->input_exprs_ctxs()[0]->root())->column_id();
if (_aggregate_evaluators[i]->is_merge()) {
- auto column = block->get_by_position(col_id).column;
+ auto column =
+ block->get_by_position(_is_update_stage ? ((VSlotRef*)_aggregate_evaluators[i]
+ ->input_exprs_ctxs()[0]
+ ->root())
+ ->column_id()
+ : i + key_size)
+ .column;
if (column->is_nullable()) {
column = ((ColumnNullable*)column.get())->get_nested_column_ptr();
}
diff --git a/be/src/vec/exec/vaggregation_node.h b/be/src/vec/exec/vaggregation_node.h
index f020b90a6e..d46f856f8b 100644
--- a/be/src/vec/exec/vaggregation_node.h
+++ b/be/src/vec/exec/vaggregation_node.h
@@ -433,6 +433,7 @@ private:
bool _needs_finalize;
bool _is_merge;
+ bool _is_update_stage;
std::unique_ptr<MemPool> _mem_pool;
size_t _align_aggregate_states = 1;
diff --git a/be/src/vec/exprs/vslot_ref.h b/be/src/vec/exprs/vslot_ref.h
index 1bc78a4c5c..c00a018b3b 100644
--- a/be/src/vec/exprs/vslot_ref.h
+++ b/be/src/vec/exprs/vslot_ref.h
@@ -38,7 +38,6 @@ public:
virtual const std::string& expr_name() const override;
virtual std::string debug_string() const override;
virtual bool is_constant() const override { return false; }
-
const int column_id() const { return _column_id; }
private:
diff --git a/be/src/vec/functions/functions_logical.cpp b/be/src/vec/functions/functions_logical.cpp
index 8ef554fd9b..0c8bded880 100644
--- a/be/src/vec/functions/functions_logical.cpp
+++ b/be/src/vec/functions/functions_logical.cpp
@@ -204,7 +204,7 @@ class AssociativeGenericApplierImpl {
public:
/// Remembers the last N columns from `in`.
AssociativeGenericApplierImpl(const ColumnRawPtrs& in)
- : val_getter{ValueGetterBuilder::build(in[in.size() - N])}, next{in} {}
+ : val_getter {ValueGetterBuilder::build(in[in.size() - N])}, next {in} {}
/// Returns a combination of values in the i-th row of all columns stored in the constructor.
inline ResultValueType apply(const size_t i) const {
@@ -227,7 +227,7 @@ class AssociativeGenericApplierImpl<Op, 1> {
public:
/// Remembers the last N columns from `in`.
AssociativeGenericApplierImpl(const ColumnRawPtrs& in)
- : val_getter{ValueGetterBuilder::build(in[in.size() - 1])} {}
+ : val_getter {ValueGetterBuilder::build(in[in.size() - 1])} {}
inline ResultValueType apply(const size_t i) const { return val_getter(i); }
@@ -449,14 +449,20 @@ Status FunctionAnyArityLogical<Impl, Name>::execute_impl(FunctionContext* contex
size_t result_index,
size_t input_rows_count) {
ColumnRawPtrs args_in;
- for (const auto arg_index : arguments)
- args_in.push_back(block.get_by_position(arg_index).column.get());
+ bool is_nullable = false;
+ for (const auto arg_index : arguments) {
+ auto& data = block.get_by_position(arg_index);
+ args_in.push_back(data.column.get());
+ is_nullable |= data.column->is_nullable();
+ }
auto& result_info = block.get_by_position(result_index);
- if (result_info.type->is_nullable())
+ if (is_nullable) {
+ result_info.type = make_nullable(result_info.type);
execute_for_ternary_logic_impl<Impl>(std::move(args_in), result_info, input_rows_count);
- else
+ } else {
basic_execute_impl<Impl>(std::move(args_in), result_info, input_rows_count);
+ }
return Status::OK();
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
index 65ee4d6d8a..e7dee2651c 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
@@ -277,6 +277,7 @@ public class AggregationNode extends PlanNode {
aggInfo.getIntermediateTupleId().asInt(),
aggInfo.getOutputTupleId().asInt(), needsFinalize);
msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
+ msg.agg_node.setIsUpdateStage(!aggInfo.isMerge());
List<Expr> groupingExprs = aggInfo.getGroupingExprs();
if (groupingExprs != null) {
msg.agg_node.setGroupingExprs(Expr.treesToThrift(groupingExprs));
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 34be03f8b6..dfe90968cb 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -464,6 +464,7 @@ struct TAggregationNode {
// rows have been aggregated, and this node is not an intermediate node.
5: required bool need_finalize
6: optional bool use_streaming_preaggregation
+ 7: optional bool is_update_stage
}
struct TRepeatNode {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org