You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/07/28 01:13:04 UTC
[doris] branch master updated: [Vectorized] Support order by aggregate function (#11187)
This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 0b1d06bfd6 [Vectorized] Support order by aggregate function (#11187)
0b1d06bfd6 is described below
commit 0b1d06bfd66ec8ac62d94fa4e3a8578f54516f7e
Author: HappenLee <ha...@hotmail.com>
AuthorDate: Thu Jul 28 09:12:58 2022 +0800
[Vectorized] Support order by aggregate function (#11187)
Co-authored-by: lihaopeng <li...@baidu.com>
---
.../aggregate_function_simple_factory.cpp | 2 -
.../aggregate_function_sort.cpp | 79 +++-------------------
.../aggregate_functions/aggregate_function_sort.h | 73 +++++++++-----------
be/src/vec/exec/vaggregation_node.cpp | 4 +-
be/src/vec/exec/vanalytic_eval_node.cpp | 2 +-
be/src/vec/exprs/vectorized_agg_fn.cpp | 36 ++++++++--
be/src/vec/exprs/vectorized_agg_fn.h | 10 ++-
fe/fe-core/src/main/cup/sql_parser.cup | 4 +-
.../org/apache/doris/analysis/AggregateInfo.java | 15 ++--
.../apache/doris/analysis/FunctionCallExpr.java | 56 +++++++++++++--
.../java/org/apache/doris/analysis/SelectList.java | 7 ++
.../apache/doris/catalog/AggregateFunction.java | 2 +
.../org/apache/doris/planner/AggregationNode.java | 14 ++++
gensrc/thrift/PlanNodes.thrift | 1 +
.../data/query/group_concat/test_group_concat.out | 20 ++++++
.../query/group_concat/test_group_concat.groovy | 8 +++
16 files changed, 192 insertions(+), 141 deletions(-)
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index c2169eb223..73779f8ffa 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -84,8 +84,6 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_window_lead_lag(instance);
register_aggregate_function_HLL_union_agg(instance);
register_aggregate_function_percentile_approx(instance);
-
- register_aggregate_function_combinator_sort(instance);
});
return instance;
}
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
index fbdb16df4f..b0566b829f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
@@ -19,84 +19,21 @@
#include "vec/aggregate_functions/aggregate_function_combinator.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/helpers.h"
#include "vec/common/typeid_cast.h"
#include "vec/data_types/data_type_nullable.h"
-#include "vec/utils/template_helpers.hpp"
namespace doris::vectorized {
-class AggregateFunctionCombinatorSort final : public IAggregateFunctionCombinator {
-private:
- int _sort_column_number;
-
-public:
- AggregateFunctionCombinatorSort(int sort_column_number)
- : _sort_column_number(sort_column_number) {}
-
- String get_name() const override { return "Sort"; }
-
- DataTypes transform_arguments(const DataTypes& arguments) const override {
- if (arguments.size() < _sort_column_number + 2) {
- LOG(FATAL) << "Incorrect number of arguments for aggregate function with Sort, "
- << arguments.size() << " less than " << _sort_column_number + 2;
- }
-
- DataTypes nested_types;
- nested_types.assign(arguments.begin(), arguments.end() - 1 - _sort_column_number);
- return nested_types;
+AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function,
+ const DataTypes& arguments,
+ const SortDescription& sort_desc) {
+ DCHECK(nested_function != nullptr);
+ if (nested_function == nullptr) {
+ return nullptr;
}
- template <int sort_column_number>
- struct Reducer {
- static void run(AggregateFunctionPtr& function, const AggregateFunctionPtr& nested_function,
- const DataTypes& arguments) {
- function = std::make_shared<
- AggregateFunctionSort<sort_column_number, AggregateFunctionSortData>>(
- nested_function, arguments);
- }
- };
-
- AggregateFunctionPtr transform_aggregate_function(
- const AggregateFunctionPtr& nested_function, const DataTypes& arguments,
- const Array& params, const bool result_is_nullable) const override {
- DCHECK(nested_function != nullptr);
- if (nested_function == nullptr) {
- return nullptr;
- }
-
- AggregateFunctionPtr function = nullptr;
- constexpr_int_match<1, 3, Reducer>::run(_sort_column_number, function, nested_function,
- arguments);
-
- return function;
- }
+ return std::make_shared<AggregateFunctionSort<AggregateFunctionSortData>>(nested_function,
+ arguments, sort_desc);
};
-const std::string SORT_FUNCTION_PREFIX = "sort_";
-
-void register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory& factory) {
- AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
- const Array& params, const bool result_is_nullable) {
- int sort_column_number = std::stoi(name.substr(SORT_FUNCTION_PREFIX.size(), 2));
- auto nested_function_name = name.substr(SORT_FUNCTION_PREFIX.size() + 2);
-
- auto function_combinator =
- std::make_shared<AggregateFunctionCombinatorSort>(sort_column_number);
-
- auto transform_arguments = function_combinator->transform_arguments(types);
-
- auto nested_function =
- factory.get(nested_function_name, transform_arguments, params, result_is_nullable);
- return function_combinator->transform_aggregate_function(nested_function, types, params,
- result_is_nullable);
- };
-
- for (char c = '1'; c <= '3'; c++) {
- factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_",
- false);
- factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_",
- true);
- }
-}
} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.h b/be/src/vec/aggregate_functions/aggregate_function_sort.h
index 5cad555c37..cd16d29770 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sort.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sort.h
@@ -18,6 +18,7 @@
#pragma once
#include <string>
+#include <utility>
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/key_holder_helpers.h"
@@ -35,10 +36,17 @@
namespace doris::vectorized {
-template <int sort_column_size>
struct AggregateFunctionSortData {
+ const SortDescription sort_desc;
+ Block block;
+
+ // This default constructor exists only to satisfy template compilation; it is otherwise unused
+ AggregateFunctionSortData() {};
+ AggregateFunctionSortData(SortDescription sort_desc, const Block& block)
+ : sort_desc(std::move(sort_desc)), block(block.clone_empty()) {}
+
void merge(const AggregateFunctionSortData& rhs) {
- if (block.is_empty_column()) {
+ if (block.rows() == 0) {
block = rhs.block;
} else {
for (size_t i = 0; i < block.columns(); i++) {
@@ -78,45 +86,18 @@ struct AggregateFunctionSortData {
}
}
- void sort() {
- size_t sort_desc_idx = block.columns() - sort_column_size - 1;
- StringRef desc_str =
- block.get_by_position(sort_desc_idx).column->assume_mutable()->get_data_at(0);
- DCHECK(sort_column_size == desc_str.size);
-
- SortDescription sort_description(sort_column_size);
- for (size_t i = 0; i < sort_column_size; i++) {
- sort_description[i].column_number = sort_desc_idx + 1 + i;
- sort_description[i].direction = (desc_str.data[i] == '0') ? 1 : -1;
- sort_description[i].nulls_direction = sort_description[i].direction;
- }
-
- sort_block(block, sort_description, block.rows());
- }
-
- void try_init(const DataTypes& _arguments) {
- if (!block.is_empty_column()) {
- return;
- }
-
- for (auto type : _arguments) {
- block.insert({type, ""});
- }
- }
-
- Block block;
+ void sort() { sort_block(block, sort_desc, block.rows()); }
};
-template <int sort_column_size, template <int> typename Data>
+template <typename Data>
class AggregateFunctionSort
- : public IAggregateFunctionDataHelper<Data<sort_column_size>,
- AggregateFunctionSort<sort_column_size, Data>> {
- using DataReal = Data<sort_column_size>;
-
+ : public IAggregateFunctionDataHelper<Data, AggregateFunctionSort<Data>> {
private:
- static constexpr auto prefix_size = sizeof(DataReal);
+ static constexpr auto prefix_size = sizeof(Data);
AggregateFunctionPtr _nested_func;
DataTypes _arguments;
+ const SortDescription& _sort_desc;
+ Block _block;
AggregateDataPtr get_nested_place(AggregateDataPtr __restrict place) const noexcept {
return place + prefix_size;
@@ -127,15 +108,20 @@ private:
}
public:
- AggregateFunctionSort(AggregateFunctionPtr nested_func, const DataTypes& arguments)
- : IAggregateFunctionDataHelper<DataReal, AggregateFunctionSort>(
+ AggregateFunctionSort(const AggregateFunctionPtr& nested_func, const DataTypes& arguments,
+ const SortDescription& sort_desc)
+ : IAggregateFunctionDataHelper<Data, AggregateFunctionSort>(
arguments, nested_func->get_parameters()),
_nested_func(nested_func),
- _arguments(arguments) {}
+ _arguments(arguments),
+ _sort_desc(sort_desc) {
+ for (const auto& type : _arguments) {
+ _block.insert({type, ""});
+ }
+ }
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena* arena) const override {
- this->data(place).try_init(_arguments);
this->data(place).add(columns, _arguments.size(), row_num);
}
@@ -159,7 +145,7 @@ public:
this->data(place).sort();
ColumnRawPtrs arguments_nested;
- for (int i = 0; i < _arguments.size() - 1 - sort_column_size; i++) {
+ for (int i = 0; i < _arguments.size() - _sort_desc.size(); i++) {
arguments_nested.emplace_back(
this->data(place).block.get_by_position(i).column.get());
}
@@ -176,12 +162,12 @@ public:
size_t align_of_data() const override { return _nested_func->align_of_data(); }
void create(AggregateDataPtr __restrict place) const override {
- new (place) DataReal;
+ new (place) Data(_sort_desc, _block);
_nested_func->create(get_nested_place(place));
}
void destroy(AggregateDataPtr __restrict place) const noexcept override {
- this->data(place).~DataReal();
+ this->data(place).~Data();
_nested_func->destroy(get_nested_place(place));
}
@@ -190,4 +176,7 @@ public:
DataTypePtr get_return_type() const override { return _nested_func->get_return_type(); }
};
+AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function,
+ const DataTypes& arguments,
+ const SortDescription& sort_desc);
} // namespace doris::vectorized
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index e7008957eb..d4325bc17e 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -116,8 +116,8 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
_aggregate_evaluators.reserve(tnode.agg_node.aggregate_functions.size());
for (int i = 0; i < tnode.agg_node.aggregate_functions.size(); ++i) {
AggFnEvaluator* evaluator = nullptr;
- RETURN_IF_ERROR(
- AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i], &evaluator));
+ RETURN_IF_ERROR(AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i],
+ tnode.agg_node.agg_sort_infos[i], &evaluator));
_aggregate_evaluators.push_back(evaluator);
}
diff --git a/be/src/vec/exec/vanalytic_eval_node.cpp b/be/src/vec/exec/vanalytic_eval_node.cpp
index 3906062cc4..dde7cb8453 100644
--- a/be/src/vec/exec/vanalytic_eval_node.cpp
+++ b/be/src/vec/exec/vanalytic_eval_node.cpp
@@ -127,7 +127,7 @@ Status VAnalyticEvalNode::init(const TPlanNode& tnode, RuntimeState* state) {
AggFnEvaluator* evaluator = nullptr;
RETURN_IF_ERROR(
- AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], &evaluator));
+ AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], {}, &evaluator));
_agg_functions.emplace_back(evaluator);
for (size_t j = 0; j < _agg_expr_ctxs[i].size(); ++j) {
_agg_intput_columns[i][j] = _agg_expr_ctxs[i][j]->root()->data_type()->create_column();
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 427b90ec2f..527cb40e18 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -23,11 +23,13 @@
#include "vec/aggregate_functions/aggregate_function_java_udaf.h"
#include "vec/aggregate_functions/aggregate_function_rpc.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_sort.h"
#include "vec/columns/column_nullable.h"
#include "vec/core/materialize_block.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/data_types/data_type_nullable.h"
#include "vec/exprs/vexpr.h"
+
namespace doris::vectorized {
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
@@ -46,12 +48,14 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
_data_type = DataTypeFactory::instance().create_data_type(_return_type, nullable);
auto& param_types = desc.agg_expr.param_types;
- for (auto raw_type : param_types) {
- _argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
+ for (int i = 0; i < param_types.size(); i++) {
+ _argument_types_with_sort.push_back(
+ DataTypeFactory::instance().create_data_type(param_types[i]));
}
}
-Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result) {
+Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
+ AggFnEvaluator** result) {
*result = pool->add(new AggFnEvaluator(desc.nodes[0]));
auto& agg_fn_evaluator = *result;
int node_idx = 0;
@@ -63,6 +67,22 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluato
VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, &node_idx, &expr, &ctx));
agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
}
+
+ auto sort_size = sort_info.ordering_exprs.size();
+ auto real_arguments_size = agg_fn_evaluator->_argument_types_with_sort.size() - sort_size;
+ // Child arguments contain [real arguments, order by arguments]; we pass these arguments
+ // to the order by function
+ for (int i = 0; i < sort_size; ++i) {
+ agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i,
+ sort_info.is_asc_order[i] == true,
+ sort_info.nulls_first[i] == true);
+ }
+
+ // Pass the real arguments to get functions
+ for (int i = 0; i < real_arguments_size; ++i) {
+ agg_fn_evaluator->_real_argument_types.emplace_back(
+ agg_fn_evaluator->_argument_types_with_sort[i]);
+ }
return Status::OK();
}
@@ -87,20 +107,24 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M
if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
#ifdef LIBJVM
- _function = AggregateJavaUdaf::create(_fn, _argument_types, {}, _data_type);
+ _function = AggregateJavaUdaf::create(_fn, _real_argument_types, {}, _data_type);
#else
return Status::InternalError("Java UDAF is disabled since no libjvm is found!");
#endif
} else if (_fn.binary_type == TFunctionBinaryType::RPC) {
- _function = AggregateRpcUdaf::create(_fn, _argument_types, {}, _data_type);
+ _function = AggregateRpcUdaf::create(_fn, _real_argument_types, {}, _data_type);
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
- _fn.name.function_name, _argument_types, {}, _data_type->is_nullable());
+ _fn.name.function_name, _real_argument_types, {}, _data_type->is_nullable());
}
if (_function == nullptr) {
return Status::InternalError("Agg Function {} is not implemented", _fn.name.function_name);
}
+ if (!_sort_description.empty()) {
+ _function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
+ _sort_description);
+ }
_expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
return Status::OK();
}
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h
index a541257487..52098f0d8f 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -20,6 +20,7 @@
#include "util/runtime_profile.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/core/block.h"
+#include "vec/core/sort_description.h"
#include "vec/data_types/data_type.h"
#include "vec/exprs/vexpr_context.h"
@@ -29,7 +30,8 @@ class SlotDescriptor;
namespace vectorized {
class AggFnEvaluator {
public:
- static Status create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result);
+ static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
+ AggFnEvaluator** result);
Status prepare(RuntimeState* state, const RowDescriptor& desc, MemPool* pool,
const SlotDescriptor* intermediate_slot_desc,
@@ -80,7 +82,9 @@ private:
void _calc_argment_columns(Block* block);
- DataTypes _argument_types;
+ DataTypes _argument_types_with_sort;
+ DataTypes _real_argument_types;
+
const TypeDescriptor _return_type;
const SlotDescriptor* _intermediate_slot_desc;
@@ -93,6 +97,8 @@ private:
// input context
std::vector<VExprContext*> _input_exprs_ctxs;
+ SortDescription _sort_description;
+
DataTypePtr _data_type;
AggregateFunctionPtr _function;
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup
index 556ab0785f..1042783da3 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -4950,8 +4950,8 @@ non_pred_expr ::=
{: RESULT = new FunctionCallExpr(fn_name, exprs); :}
//| function_name:fn_name LPAREN RPAREN
//{: RESULT = new FunctionCallExpr(fn_name, new ArrayList<Expr>()); :}
- //| function_name:fn_name LPAREN function_params:params RPAREN
- //{: RESULT = new FunctionCallExpr(fn_name, params); :}
+ | function_name:fn_name LPAREN function_params:params order_by_clause:o RPAREN
+ {: RESULT = new FunctionCallExpr(fn_name, params, o); :}
| analytic_expr:e
{: RESULT = e; :}
/* Since "IF" is a keyword, need to special case this function */
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index a0152a8e8a..79e01e0cae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -496,6 +496,8 @@ public final class AggregateInfo extends AggregateInfoBase {
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs());
aggExpr.analyzeNoThrow(analyzer);
+ // No need to analyze in the merge stage; just set the order by elements so BE gets the right function
+ aggExpr.setOrderByElements(inputExpr.getOrderByElements());
aggExprs.add(aggExpr);
}
@@ -621,7 +623,6 @@ public final class AggregateInfo extends AggregateInfoBase {
}
Preconditions.checkState(
secondPhaseAggExprs.size() == aggregateExprs.size() + distinctAggExprs.size());
-
for (FunctionCallExpr aggExpr : secondPhaseAggExprs) {
aggExpr.analyzeNoThrow(analyzer);
Preconditions.checkState(aggExpr.isAggregateFunction());
@@ -649,18 +650,16 @@ public final class AggregateInfo extends AggregateInfoBase {
int numDistinctParams = 0;
if (!isMultiDistinct) {
numDistinctParams = distinctAggExprs.get(0).getChildren().size();
- // If we are counting distinct params of group_concat, we cannot include the custom
- // separator since it is not a distinct param.
- if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat")
- && numDistinctParams == 2) {
- --numDistinctParams;
- }
} else {
for (int i = 0; i < distinctAggExprs.size(); i++) {
numDistinctParams += distinctAggExprs.get(i).getChildren().size();
}
}
-
+ // If we are counting distinct params of group_concat, we cannot include the custom
+ // separator since it is not a distinct param.
+ if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat")) {
+ numDistinctParams = 1;
+ }
int numOrigGroupingExprs = inputAggInfo.getGroupingExprs().size() - numDistinctParams;
Preconditions.checkState(
slotDescs.size() == numOrigGroupingExprs + distinctAggExprs.size()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index ee65f35469..6038d5e3f1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -90,9 +90,10 @@ public class FunctionCallExpr extends Expr {
// private BuiltinAggregateFunction.Operator aggOp;
private FunctionParams fnParams;
- // represent original parament from aggregate function
private FunctionParams aggFnParams;
+ private List<OrderByElement> orderByElements = Lists.newArrayList();
+
// check analytic function
private boolean isAnalyticFnCall = false;
// check table function
@@ -155,6 +156,27 @@ public class FunctionCallExpr extends Expr {
this(fnName, params, false);
}
+ public FunctionCallExpr(
+ FunctionName fnName, FunctionParams params, List<OrderByElement> orderByElements) throws AnalysisException {
+ this(fnName, params, false);
+ this.orderByElements = orderByElements;
+ if (!orderByElements.isEmpty()) {
+ if (!VectorizedUtil.isVectorized()) {
+ throw new AnalysisException(
+ "ORDER BY for arguments only support in vec exec engine");
+ } else if (!AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
+ fnName.getFunction().toLowerCase())) {
+ throw new AnalysisException(
+ "ORDER BY not support for the function:" + fnName.getFunction().toLowerCase());
+ } else if (params.isDistinct()) {
+ throw new AnalysisException(
+ "ORDER BY not support for the distinct, support in the furture:"
+ + fnName.getFunction().toLowerCase());
+ }
+ }
+ setChildren();
+ }
+
private FunctionCallExpr(
FunctionName fnName, FunctionParams params, boolean isMergeAggFn) {
super();
@@ -187,6 +209,7 @@ public class FunctionCallExpr extends Expr {
protected FunctionCallExpr(FunctionCallExpr other) {
super(other);
fnName = other.fnName;
+ orderByElements = other.orderByElements;
isAnalyticFnCall = other.isAnalyticFnCall;
// aggOp = other.aggOp;
// fnParams = other.fnParams;
@@ -289,6 +312,8 @@ public class FunctionCallExpr extends Expr {
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) {
result.add("\'***\'");
+ } else if (orderByElements.size() > 0 && i == len - orderByElements.size()) {
+ result.add("ORDER BY " + children.get(i).toSql());
} else {
result.add(children.get(i).toSql());
}
@@ -503,7 +528,7 @@ public class FunctionCallExpr extends Expr {
}
if (fnName.getFunction().equalsIgnoreCase("group_concat")) {
- if (children.size() > 2 || children.isEmpty()) {
+ if (children.size() - orderByElements.size() > 2 || children.isEmpty()) {
throw new AnalysisException(
"group_concat requires one or two parameters: " + this.toSql());
}
@@ -514,13 +539,14 @@ public class FunctionCallExpr extends Expr {
"group_concat requires first parameter to be of type STRING: " + this.toSql());
}
- if (children.size() == 2) {
+ if (children.size() - orderByElements.size() == 2) {
Expr arg1 = getChild(1);
if (!arg1.type.isStringType() && !arg1.type.isNull()) {
throw new AnalysisException(
"group_concat requires second parameter to be of type STRING: " + this.toSql());
}
}
+
return;
}
@@ -926,6 +952,15 @@ public class FunctionCallExpr extends Expr {
childTypes[2] = assignmentCompatibleType;
fn = getBuiltinFunction(fnName.getFunction(), childTypes,
Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
+ } else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
+ fnName.getFunction().toLowerCase())) {
+ // Order by elements are added as children, like window functions, so when we collect the
+ // argument types we need to remove the order by elements
+ Type[] childTypes = collectChildReturnTypes();
+ Type[] newChildTypes = new Type[children.size() - orderByElements.size()];
+ System.arraycopy(childTypes, 0, newChildTypes, 0, newChildTypes.length);
+ fn = getBuiltinFunction(fnName.getFunction(), newChildTypes,
+ Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
} else {
// now first find table function in table function sets
if (isTableFnCall) {
@@ -1024,7 +1059,7 @@ public class FunctionCallExpr extends Expr {
Type[] args = fn.getArgs();
if (args.length > 0) {
// Implicitly cast all the children to match the function if necessary
- for (int i = 0; i < argTypes.length; ++i) {
+ for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
// For varargs, we must compare with the last type in callArgs.argTypes.
int ix = Math.min(args.length - 1, i);
if (!argTypes[i].matchesType(args[ix]) && Config.use_date_v2_by_default
@@ -1327,7 +1362,6 @@ public class FunctionCallExpr extends Expr {
return result.toString();
}
- @Override
public void finalizeImplForNereids() throws AnalysisException {
// TODO: support other functions
// TODO: Supports type conversion to match the type of the function's parameters
@@ -1356,4 +1390,16 @@ public class FunctionCallExpr extends Expr {
public void setMergeForNereids(boolean isMergeAggFn) {
this.isMergeAggFn = isMergeAggFn;
}
+
+ public List<OrderByElement> getOrderByElements() {
+ return orderByElements;
+ }
+
+ public void setOrderByElements(List<OrderByElement> orderByElements) {
+ this.orderByElements = orderByElements;
+ }
+
+ private void setChildren() {
+ orderByElements.forEach(o -> addChild(o.getExpr()));
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
index 77a2084f79..ee950a032e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
@@ -37,6 +37,7 @@ public class SelectList {
private boolean isDistinct;
private Map<String, String> optHints;
+ private List<OrderByElement> orderByElements;
// ///////////////////////////////////////
// BEGIN: Members that need to be reset()
@@ -90,6 +91,12 @@ public class SelectList {
}
}
+ public void setOrderByElements(List<OrderByElement> orderByElements) {
+ if (orderByElements != null) {
+ this.orderByElements = orderByElements;
+ }
+ }
+
public void reset() {
for (SelectListItem item : items) {
if (!item.isStar()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index 72957c0eff..a58097e9eb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -55,6 +55,8 @@ public class AggregateFunction extends Function {
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx");
+ public static ImmutableSet<String> SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("group_concat");
+
// Set if different from retType_, null otherwise.
private Type intermediateType;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
index c8561b54dc..4cdd1aceee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
@@ -36,6 +36,7 @@ import org.apache.doris.thrift.TExplainLevel;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TPlanNode;
import org.apache.doris.thrift.TPlanNodeType;
+import org.apache.doris.thrift.TSortInfo;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
@@ -249,14 +250,27 @@ public class AggregationNode extends PlanNode {
protected void toThrift(TPlanNode msg) {
msg.node_type = TPlanNodeType.AGGREGATION_NODE;
List<TExpr> aggregateFunctions = Lists.newArrayList();
+ List<TSortInfo> aggSortInfos = Lists.newArrayList();
// only serialize agg exprs that are being materialized
for (FunctionCallExpr e : aggInfo.getMaterializedAggregateExprs()) {
aggregateFunctions.add(e.treeToThrift());
+ List<TExpr> orderingExpr = Lists.newArrayList();
+ List<Boolean> isAscs = Lists.newArrayList();
+ List<Boolean> nullFirsts = Lists.newArrayList();
+
+ e.getOrderByElements().forEach(o -> {
+ orderingExpr.add(o.getExpr().treeToThrift());
+ isAscs.add(o.getIsAsc());
+ nullFirsts.add(o.getNullsFirstParam());
+ });
+ aggSortInfos.add(new TSortInfo(orderingExpr, isAscs, nullFirsts));
}
+
msg.agg_node = new TAggregationNode(
aggregateFunctions,
aggInfo.getIntermediateTupleId().asInt(),
aggInfo.getOutputTupleId().asInt(), needsFinalize);
+ msg.agg_node.setAggSortInfos(aggSortInfos);
msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
List<Expr> groupingExprs = aggInfo.getGroupingExprs();
if (groupingExprs != null) {
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 3024bac292..b3823f1cf6 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -543,6 +543,7 @@ struct TAggregationNode {
// rows have been aggregated, and this node is not an intermediate node.
5: required bool need_finalize
6: optional bool use_streaming_preaggregation
+ 7: optional list<TSortInfo> agg_sort_infos
}
struct TRepeatNode {
diff --git a/regression-test/data/query/group_concat/test_group_concat.out b/regression-test/data/query/group_concat/test_group_concat.out
index 94f73cc536..e61bd4fd4f 100644
--- a/regression-test/data/query/group_concat/test_group_concat.out
+++ b/regression-test/data/query/group_concat/test_group_concat.out
@@ -5,3 +5,23 @@ false, false
-- !select --
false
+-- !select --
+\N \N
+103 255
+1001 1986, 1989
+1002 1989, 32767
+3021 1991, 1992, 32767
+5014 1985, 1991
+25699 1989
+2147483647 255, 1991, 32767, 32767
+
+-- !select --
+\N \N
+103 255
+1001 1986:1989
+1002 1989:32767
+3021 1991:1992:32767
+5014 1985:1991
+25699 1989
+2147483647 255:1991:32767:32767
+
diff --git a/regression-test/suites/query/group_concat/test_group_concat.groovy b/regression-test/suites/query/group_concat/test_group_concat.groovy
index 6bb57dea44..12d420cbe0 100644
--- a/regression-test/suites/query/group_concat/test_group_concat.groovy
+++ b/regression-test/suites/query/group_concat/test_group_concat.groovy
@@ -23,4 +23,12 @@ suite("test_group_concat", "query") {
qt_select """
SELECT group_concat(DISTINCT k6) FROM test_query_db.test where k6='false'
"""
+
+ qt_select """
+ SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
+ """
+
+ qt_select """
+ SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org