You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by yi...@apache.org on 2022/07/28 01:13:04 UTC

[doris] branch master updated: [Vectorized] Support order by aggregate function (#11187)

This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b1d06bfd6 [Vectorized] Support order by aggregate function (#11187)
0b1d06bfd6 is described below

commit 0b1d06bfd66ec8ac62d94fa4e3a8578f54516f7e
Author: HappenLee <ha...@hotmail.com>
AuthorDate: Thu Jul 28 09:12:58 2022 +0800

    [Vectorized] Support order by aggregate function (#11187)
    
    
    Co-authored-by: lihaopeng <li...@baidu.com>
---
 .../aggregate_function_simple_factory.cpp          |  2 -
 .../aggregate_function_sort.cpp                    | 79 +++-------------------
 .../aggregate_functions/aggregate_function_sort.h  | 73 +++++++++-----------
 be/src/vec/exec/vaggregation_node.cpp              |  4 +-
 be/src/vec/exec/vanalytic_eval_node.cpp            |  2 +-
 be/src/vec/exprs/vectorized_agg_fn.cpp             | 36 ++++++++--
 be/src/vec/exprs/vectorized_agg_fn.h               | 10 ++-
 fe/fe-core/src/main/cup/sql_parser.cup             |  4 +-
 .../org/apache/doris/analysis/AggregateInfo.java   | 15 ++--
 .../apache/doris/analysis/FunctionCallExpr.java    | 56 +++++++++++++--
 .../java/org/apache/doris/analysis/SelectList.java |  7 ++
 .../apache/doris/catalog/AggregateFunction.java    |  2 +
 .../org/apache/doris/planner/AggregationNode.java  | 14 ++++
 gensrc/thrift/PlanNodes.thrift                     |  1 +
 .../data/query/group_concat/test_group_concat.out  | 20 ++++++
 .../query/group_concat/test_group_concat.groovy    |  8 +++
 16 files changed, 192 insertions(+), 141 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index c2169eb223..73779f8ffa 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -84,8 +84,6 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
         register_aggregate_function_window_lead_lag(instance);
         register_aggregate_function_HLL_union_agg(instance);
         register_aggregate_function_percentile_approx(instance);
-
-        register_aggregate_function_combinator_sort(instance);
     });
     return instance;
 }
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
index fbdb16df4f..b0566b829f 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_sort.cpp
@@ -19,84 +19,21 @@
 
 #include "vec/aggregate_functions/aggregate_function_combinator.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
-#include "vec/aggregate_functions/helpers.h"
 #include "vec/common/typeid_cast.h"
 #include "vec/data_types/data_type_nullable.h"
-#include "vec/utils/template_helpers.hpp"
 
 namespace doris::vectorized {
 
-class AggregateFunctionCombinatorSort final : public IAggregateFunctionCombinator {
-private:
-    int _sort_column_number;
-
-public:
-    AggregateFunctionCombinatorSort(int sort_column_number)
-            : _sort_column_number(sort_column_number) {}
-
-    String get_name() const override { return "Sort"; }
-
-    DataTypes transform_arguments(const DataTypes& arguments) const override {
-        if (arguments.size() < _sort_column_number + 2) {
-            LOG(FATAL) << "Incorrect number of arguments for aggregate function with Sort, "
-                       << arguments.size() << " less than " << _sort_column_number + 2;
-        }
-
-        DataTypes nested_types;
-        nested_types.assign(arguments.begin(), arguments.end() - 1 - _sort_column_number);
-        return nested_types;
+AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function,
+                                                    const DataTypes& arguments,
+                                                    const SortDescription& sort_desc) {
+    DCHECK(nested_function != nullptr);
+    if (nested_function == nullptr) {
+        return nullptr;
     }
 
-    template <int sort_column_number>
-    struct Reducer {
-        static void run(AggregateFunctionPtr& function, const AggregateFunctionPtr& nested_function,
-                        const DataTypes& arguments) {
-            function = std::make_shared<
-                    AggregateFunctionSort<sort_column_number, AggregateFunctionSortData>>(
-                    nested_function, arguments);
-        }
-    };
-
-    AggregateFunctionPtr transform_aggregate_function(
-            const AggregateFunctionPtr& nested_function, const DataTypes& arguments,
-            const Array& params, const bool result_is_nullable) const override {
-        DCHECK(nested_function != nullptr);
-        if (nested_function == nullptr) {
-            return nullptr;
-        }
-
-        AggregateFunctionPtr function = nullptr;
-        constexpr_int_match<1, 3, Reducer>::run(_sort_column_number, function, nested_function,
-                                                arguments);
-
-        return function;
-    }
+    return std::make_shared<AggregateFunctionSort<AggregateFunctionSortData>>(nested_function,
+                                                                              arguments, sort_desc);
 };
 
-const std::string SORT_FUNCTION_PREFIX = "sort_";
-
-void register_aggregate_function_combinator_sort(AggregateFunctionSimpleFactory& factory) {
-    AggregateFunctionCreator creator = [&](const std::string& name, const DataTypes& types,
-                                           const Array& params, const bool result_is_nullable) {
-        int sort_column_number = std::stoi(name.substr(SORT_FUNCTION_PREFIX.size(), 2));
-        auto nested_function_name = name.substr(SORT_FUNCTION_PREFIX.size() + 2);
-
-        auto function_combinator =
-                std::make_shared<AggregateFunctionCombinatorSort>(sort_column_number);
-
-        auto transform_arguments = function_combinator->transform_arguments(types);
-
-        auto nested_function =
-                factory.get(nested_function_name, transform_arguments, params, result_is_nullable);
-        return function_combinator->transform_aggregate_function(nested_function, types, params,
-                                                                 result_is_nullable);
-    };
-
-    for (char c = '1'; c <= '3'; c++) {
-        factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_",
-                                                      false);
-        factory.register_distinct_function_combinator(creator, SORT_FUNCTION_PREFIX + c + "_",
-                                                      true);
-    }
-}
 } // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_sort.h b/be/src/vec/aggregate_functions/aggregate_function_sort.h
index 5cad555c37..cd16d29770 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_sort.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_sort.h
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <string>
+#include <utility>
 
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/aggregate_functions/key_holder_helpers.h"
@@ -35,10 +36,17 @@
 
 namespace doris::vectorized {
 
-template <int sort_column_size>
 struct AggregateFunctionSortData {
+    const SortDescription sort_desc;
+    Block block;
+
+    // Default constructor exists only so the templates compile; it is otherwise unused
+    AggregateFunctionSortData() {};
+    AggregateFunctionSortData(SortDescription sort_desc, const Block& block)
+            : sort_desc(std::move(sort_desc)), block(block.clone_empty()) {}
+
     void merge(const AggregateFunctionSortData& rhs) {
-        if (block.is_empty_column()) {
+        if (block.rows() == 0) {
             block = rhs.block;
         } else {
             for (size_t i = 0; i < block.columns(); i++) {
@@ -78,45 +86,18 @@ struct AggregateFunctionSortData {
         }
     }
 
-    void sort() {
-        size_t sort_desc_idx = block.columns() - sort_column_size - 1;
-        StringRef desc_str =
-                block.get_by_position(sort_desc_idx).column->assume_mutable()->get_data_at(0);
-        DCHECK(sort_column_size == desc_str.size);
-
-        SortDescription sort_description(sort_column_size);
-        for (size_t i = 0; i < sort_column_size; i++) {
-            sort_description[i].column_number = sort_desc_idx + 1 + i;
-            sort_description[i].direction = (desc_str.data[i] == '0') ? 1 : -1;
-            sort_description[i].nulls_direction = sort_description[i].direction;
-        }
-
-        sort_block(block, sort_description, block.rows());
-    }
-
-    void try_init(const DataTypes& _arguments) {
-        if (!block.is_empty_column()) {
-            return;
-        }
-
-        for (auto type : _arguments) {
-            block.insert({type, ""});
-        }
-    }
-
-    Block block;
+    void sort() { sort_block(block, sort_desc, block.rows()); }
 };
 
-template <int sort_column_size, template <int> typename Data>
+template <typename Data>
 class AggregateFunctionSort
-        : public IAggregateFunctionDataHelper<Data<sort_column_size>,
-                                              AggregateFunctionSort<sort_column_size, Data>> {
-    using DataReal = Data<sort_column_size>;
-
+        : public IAggregateFunctionDataHelper<Data, AggregateFunctionSort<Data>> {
 private:
-    static constexpr auto prefix_size = sizeof(DataReal);
+    static constexpr auto prefix_size = sizeof(Data);
     AggregateFunctionPtr _nested_func;
     DataTypes _arguments;
+    const SortDescription& _sort_desc;
+    Block _block;
 
     AggregateDataPtr get_nested_place(AggregateDataPtr __restrict place) const noexcept {
         return place + prefix_size;
@@ -127,15 +108,20 @@ private:
     }
 
 public:
-    AggregateFunctionSort(AggregateFunctionPtr nested_func, const DataTypes& arguments)
-            : IAggregateFunctionDataHelper<DataReal, AggregateFunctionSort>(
+    AggregateFunctionSort(const AggregateFunctionPtr& nested_func, const DataTypes& arguments,
+                          const SortDescription& sort_desc)
+            : IAggregateFunctionDataHelper<Data, AggregateFunctionSort>(
                       arguments, nested_func->get_parameters()),
               _nested_func(nested_func),
-              _arguments(arguments) {}
+              _arguments(arguments),
+              _sort_desc(sort_desc) {
+        for (const auto& type : _arguments) {
+            _block.insert({type, ""});
+        }
+    }
 
     void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
              Arena* arena) const override {
-        this->data(place).try_init(_arguments);
         this->data(place).add(columns, _arguments.size(), row_num);
     }
 
@@ -159,7 +145,7 @@ public:
             this->data(place).sort();
 
             ColumnRawPtrs arguments_nested;
-            for (int i = 0; i < _arguments.size() - 1 - sort_column_size; i++) {
+            for (int i = 0; i < _arguments.size() - _sort_desc.size(); i++) {
                 arguments_nested.emplace_back(
                         this->data(place).block.get_by_position(i).column.get());
             }
@@ -176,12 +162,12 @@ public:
     size_t align_of_data() const override { return _nested_func->align_of_data(); }
 
     void create(AggregateDataPtr __restrict place) const override {
-        new (place) DataReal;
+        new (place) Data(_sort_desc, _block);
         _nested_func->create(get_nested_place(place));
     }
 
     void destroy(AggregateDataPtr __restrict place) const noexcept override {
-        this->data(place).~DataReal();
+        this->data(place).~Data();
         _nested_func->destroy(get_nested_place(place));
     }
 
@@ -190,4 +176,7 @@ public:
     DataTypePtr get_return_type() const override { return _nested_func->get_return_type(); }
 };
 
+AggregateFunctionPtr transform_to_sort_agg_function(const AggregateFunctionPtr& nested_function,
+                                                    const DataTypes& arguments,
+                                                    const SortDescription& sort_desc);
 } // namespace doris::vectorized
diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp
index e7008957eb..d4325bc17e 100644
--- a/be/src/vec/exec/vaggregation_node.cpp
+++ b/be/src/vec/exec/vaggregation_node.cpp
@@ -116,8 +116,8 @@ Status AggregationNode::init(const TPlanNode& tnode, RuntimeState* state) {
     _aggregate_evaluators.reserve(tnode.agg_node.aggregate_functions.size());
     for (int i = 0; i < tnode.agg_node.aggregate_functions.size(); ++i) {
         AggFnEvaluator* evaluator = nullptr;
-        RETURN_IF_ERROR(
-                AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i], &evaluator));
+        RETURN_IF_ERROR(AggFnEvaluator::create(_pool, tnode.agg_node.aggregate_functions[i],
+                                               tnode.agg_node.agg_sort_infos[i], &evaluator));
         _aggregate_evaluators.push_back(evaluator);
     }
 
diff --git a/be/src/vec/exec/vanalytic_eval_node.cpp b/be/src/vec/exec/vanalytic_eval_node.cpp
index 3906062cc4..dde7cb8453 100644
--- a/be/src/vec/exec/vanalytic_eval_node.cpp
+++ b/be/src/vec/exec/vanalytic_eval_node.cpp
@@ -127,7 +127,7 @@ Status VAnalyticEvalNode::init(const TPlanNode& tnode, RuntimeState* state) {
 
         AggFnEvaluator* evaluator = nullptr;
         RETURN_IF_ERROR(
-                AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], &evaluator));
+                AggFnEvaluator::create(_pool, analytic_node.analytic_functions[i], {}, &evaluator));
         _agg_functions.emplace_back(evaluator);
         for (size_t j = 0; j < _agg_expr_ctxs[i].size(); ++j) {
             _agg_intput_columns[i][j] = _agg_expr_ctxs[i][j]->root()->data_type()->create_column();
diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp
index 427b90ec2f..527cb40e18 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.cpp
+++ b/be/src/vec/exprs/vectorized_agg_fn.cpp
@@ -23,11 +23,13 @@
 #include "vec/aggregate_functions/aggregate_function_java_udaf.h"
 #include "vec/aggregate_functions/aggregate_function_rpc.h"
 #include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_sort.h"
 #include "vec/columns/column_nullable.h"
 #include "vec/core/materialize_block.h"
 #include "vec/data_types/data_type_factory.hpp"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/exprs/vexpr.h"
+
 namespace doris::vectorized {
 
 AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
@@ -46,12 +48,14 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
     _data_type = DataTypeFactory::instance().create_data_type(_return_type, nullable);
 
     auto& param_types = desc.agg_expr.param_types;
-    for (auto raw_type : param_types) {
-        _argument_types.push_back(DataTypeFactory::instance().create_data_type(raw_type));
+    for (int i = 0; i < param_types.size(); i++) {
+        _argument_types_with_sort.push_back(
+                DataTypeFactory::instance().create_data_type(param_types[i]));
     }
 }
 
-Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result) {
+Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
+                              AggFnEvaluator** result) {
     *result = pool->add(new AggFnEvaluator(desc.nodes[0]));
     auto& agg_fn_evaluator = *result;
     int node_idx = 0;
@@ -63,6 +67,22 @@ Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, AggFnEvaluato
                 VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, &node_idx, &expr, &ctx));
         agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx);
     }
+
+    auto sort_size = sort_info.ordering_exprs.size();
+    auto real_arguments_size = agg_fn_evaluator->_argument_types_with_sort.size() - sort_size;
+    // Child arguments contain [real arguments, order by arguments]; we pass the arguments
+    // to the order by functions
+    for (int i = 0; i < sort_size; ++i) {
+        agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i,
+                                                         sort_info.is_asc_order[i] == true,
+                                                         sort_info.nulls_first[i] == true);
+    }
+
+    // Pass the real arguments to get functions
+    for (int i = 0; i < real_arguments_size; ++i) {
+        agg_fn_evaluator->_real_argument_types.emplace_back(
+                agg_fn_evaluator->_argument_types_with_sort[i]);
+    }
     return Status::OK();
 }
 
@@ -87,20 +107,24 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M
 
     if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
 #ifdef LIBJVM
-        _function = AggregateJavaUdaf::create(_fn, _argument_types, {}, _data_type);
+        _function = AggregateJavaUdaf::create(_fn, _real_argument_types, {}, _data_type);
 #else
         return Status::InternalError("Java UDAF is disabled since no libjvm is found!");
 #endif
     } else if (_fn.binary_type == TFunctionBinaryType::RPC) {
-        _function = AggregateRpcUdaf::create(_fn, _argument_types, {}, _data_type);
+        _function = AggregateRpcUdaf::create(_fn, _real_argument_types, {}, _data_type);
     } else {
         _function = AggregateFunctionSimpleFactory::instance().get(
-                _fn.name.function_name, _argument_types, {}, _data_type->is_nullable());
+                _fn.name.function_name, _real_argument_types, {}, _data_type->is_nullable());
     }
     if (_function == nullptr) {
         return Status::InternalError("Agg Function {} is not implemented", _fn.name.function_name);
     }
 
+    if (!_sort_description.empty()) {
+        _function = transform_to_sort_agg_function(_function, _argument_types_with_sort,
+                                                   _sort_description);
+    }
     _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name);
     return Status::OK();
 }
diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h
index a541257487..52098f0d8f 100644
--- a/be/src/vec/exprs/vectorized_agg_fn.h
+++ b/be/src/vec/exprs/vectorized_agg_fn.h
@@ -20,6 +20,7 @@
 #include "util/runtime_profile.h"
 #include "vec/aggregate_functions/aggregate_function.h"
 #include "vec/core/block.h"
+#include "vec/core/sort_description.h"
 #include "vec/data_types/data_type.h"
 #include "vec/exprs/vexpr_context.h"
 
@@ -29,7 +30,8 @@ class SlotDescriptor;
 namespace vectorized {
 class AggFnEvaluator {
 public:
-    static Status create(ObjectPool* pool, const TExpr& desc, AggFnEvaluator** result);
+    static Status create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info,
+                         AggFnEvaluator** result);
 
     Status prepare(RuntimeState* state, const RowDescriptor& desc, MemPool* pool,
                    const SlotDescriptor* intermediate_slot_desc,
@@ -80,7 +82,9 @@ private:
 
     void _calc_argment_columns(Block* block);
 
-    DataTypes _argument_types;
+    DataTypes _argument_types_with_sort;
+    DataTypes _real_argument_types;
+
     const TypeDescriptor _return_type;
 
     const SlotDescriptor* _intermediate_slot_desc;
@@ -93,6 +97,8 @@ private:
     // input context
     std::vector<VExprContext*> _input_exprs_ctxs;
 
+    SortDescription _sort_description;
+
     DataTypePtr _data_type;
 
     AggregateFunctionPtr _function;
diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup
index 556ab0785f..1042783da3 100644
--- a/fe/fe-core/src/main/cup/sql_parser.cup
+++ b/fe/fe-core/src/main/cup/sql_parser.cup
@@ -4950,8 +4950,8 @@ non_pred_expr ::=
   {: RESULT = new FunctionCallExpr(fn_name, exprs); :}
   //| function_name:fn_name LPAREN RPAREN
   //{: RESULT = new FunctionCallExpr(fn_name, new ArrayList<Expr>()); :}
-  //| function_name:fn_name LPAREN function_params:params RPAREN
-  //{: RESULT = new FunctionCallExpr(fn_name, params); :}
+  | function_name:fn_name LPAREN function_params:params order_by_clause:o RPAREN
+  {: RESULT = new FunctionCallExpr(fn_name, params, o); :}
   | analytic_expr:e
   {: RESULT = e; :}
   /* Since "IF" is a keyword, need to special case this function */
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index a0152a8e8a..79e01e0cae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -496,6 +496,8 @@ public final class AggregateInfo extends AggregateInfoBase {
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
                     inputExpr, Lists.newArrayList(aggExprParam), inputExpr.getFnParams().exprs());
             aggExpr.analyzeNoThrow(analyzer);
+            // no need to analyze in the merge stage; just mark it so BE gets the right function
+            aggExpr.setOrderByElements(inputExpr.getOrderByElements());
             aggExprs.add(aggExpr);
         }
 
@@ -621,7 +623,6 @@ public final class AggregateInfo extends AggregateInfoBase {
         }
         Preconditions.checkState(
                 secondPhaseAggExprs.size() == aggregateExprs.size() + distinctAggExprs.size());
-
         for (FunctionCallExpr aggExpr : secondPhaseAggExprs) {
             aggExpr.analyzeNoThrow(analyzer);
             Preconditions.checkState(aggExpr.isAggregateFunction());
@@ -649,18 +650,16 @@ public final class AggregateInfo extends AggregateInfoBase {
         int numDistinctParams = 0;
         if (!isMultiDistinct) {
             numDistinctParams = distinctAggExprs.get(0).getChildren().size();
-            // If we are counting distinct params of group_concat, we cannot include the custom
-            // separator since it is not a distinct param.
-            if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat")
-                    && numDistinctParams == 2) {
-                --numDistinctParams;
-            }
         } else {
             for (int i = 0; i < distinctAggExprs.size(); i++) {
                 numDistinctParams += distinctAggExprs.get(i).getChildren().size();
             }
         }
-
+        // If we are counting distinct params of group_concat, we cannot include the custom
+        // separator since it is not a distinct param.
+        if (distinctAggExprs.get(0).getFnName().getFunction().equalsIgnoreCase("group_concat")) {
+            numDistinctParams = 1;
+        }
         int numOrigGroupingExprs = inputAggInfo.getGroupingExprs().size() - numDistinctParams;
         Preconditions.checkState(
                 slotDescs.size() == numOrigGroupingExprs + distinctAggExprs.size()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index ee65f35469..6038d5e3f1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -90,9 +90,10 @@ public class FunctionCallExpr extends Expr {
     // private BuiltinAggregateFunction.Operator aggOp;
     private FunctionParams fnParams;
 
-    // represent original parament from aggregate function
     private FunctionParams aggFnParams;
 
+    private List<OrderByElement> orderByElements = Lists.newArrayList();
+
     // check analytic function
     private boolean isAnalyticFnCall = false;
     // check table function
@@ -155,6 +156,27 @@ public class FunctionCallExpr extends Expr {
         this(fnName, params, false);
     }
 
+    public FunctionCallExpr(
+            FunctionName fnName, FunctionParams params, List<OrderByElement> orderByElements) throws AnalysisException {
+        this(fnName, params, false);
+        this.orderByElements = orderByElements;
+        if (!orderByElements.isEmpty()) {
+            if (!VectorizedUtil.isVectorized()) {
+                throw new AnalysisException(
+                    "ORDER BY for arguments only support in vec exec engine");
+            } else if (!AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
+                    fnName.getFunction().toLowerCase())) {
+                throw new AnalysisException(
+                    "ORDER BY not support for the function:" + fnName.getFunction().toLowerCase());
+            } else if (params.isDistinct()) {
+                throw new AnalysisException(
+                    "ORDER BY not support for the distinct, support in the furture:"
+                        + fnName.getFunction().toLowerCase());
+            }
+        }
+        setChildren();
+    }
+
     private FunctionCallExpr(
             FunctionName fnName, FunctionParams params, boolean isMergeAggFn) {
         super();
@@ -187,6 +209,7 @@ public class FunctionCallExpr extends Expr {
     protected FunctionCallExpr(FunctionCallExpr other) {
         super(other);
         fnName = other.fnName;
+        orderByElements = other.orderByElements;
         isAnalyticFnCall = other.isAnalyticFnCall;
         //   aggOp = other.aggOp;
         // fnParams = other.fnParams;
@@ -289,6 +312,8 @@ public class FunctionCallExpr extends Expr {
                     || fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
                     || fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) {
                 result.add("\'***\'");
+            } else if (orderByElements.size() > 0 && i == len - orderByElements.size()) {
+                result.add("ORDER BY " + children.get(i).toSql());
             } else {
                 result.add(children.get(i).toSql());
             }
@@ -503,7 +528,7 @@ public class FunctionCallExpr extends Expr {
         }
 
         if (fnName.getFunction().equalsIgnoreCase("group_concat")) {
-            if (children.size() > 2 || children.isEmpty()) {
+            if (children.size() - orderByElements.size() > 2 || children.isEmpty()) {
                 throw new AnalysisException(
                         "group_concat requires one or two parameters: " + this.toSql());
             }
@@ -514,13 +539,14 @@ public class FunctionCallExpr extends Expr {
                         "group_concat requires first parameter to be of type STRING: " + this.toSql());
             }
 
-            if (children.size() == 2) {
+            if (children.size() - orderByElements.size() == 2) {
                 Expr arg1 = getChild(1);
                 if (!arg1.type.isStringType() && !arg1.type.isNull()) {
                     throw new AnalysisException(
                             "group_concat requires second parameter to be of type STRING: " + this.toSql());
                 }
             }
+
             return;
         }
 
@@ -926,6 +952,15 @@ public class FunctionCallExpr extends Expr {
             childTypes[2] = assignmentCompatibleType;
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                     Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
+        } else if (AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
+                fnName.getFunction().toLowerCase())) {
+            // order by elements are added as children, like window functions, so when we
+            // collect the argument types we need to remove the order by elements
+            Type[] childTypes = collectChildReturnTypes();
+            Type[] newChildTypes = new Type[children.size() - orderByElements.size()];
+            System.arraycopy(childTypes, 0, newChildTypes, 0, newChildTypes.length);
+            fn = getBuiltinFunction(fnName.getFunction(), newChildTypes,
+                Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
         } else {
             // now first find table function in table function sets
             if (isTableFnCall) {
@@ -1024,7 +1059,7 @@ public class FunctionCallExpr extends Expr {
             Type[] args = fn.getArgs();
             if (args.length > 0) {
                 // Implicitly cast all the children to match the function if necessary
-                for (int i = 0; i < argTypes.length; ++i) {
+                for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
                     // For varargs, we must compare with the last type in callArgs.argTypes.
                     int ix = Math.min(args.length - 1, i);
                     if (!argTypes[i].matchesType(args[ix]) && Config.use_date_v2_by_default
@@ -1327,7 +1362,6 @@ public class FunctionCallExpr extends Expr {
         return result.toString();
     }
 
-    @Override
     public void finalizeImplForNereids() throws AnalysisException {
         // TODO: support other functions
         // TODO: Supports type conversion to match the type of the function's parameters
@@ -1356,4 +1390,16 @@ public class FunctionCallExpr extends Expr {
     public void setMergeForNereids(boolean isMergeAggFn) {
         this.isMergeAggFn = isMergeAggFn;
     }
+
+    public List<OrderByElement> getOrderByElements() {
+        return orderByElements;
+    }
+
+    public void setOrderByElements(List<OrderByElement> orderByElements) {
+        this.orderByElements = orderByElements;
+    }
+
+    private void setChildren() {
+        orderByElements.forEach(o -> addChild(o.getExpr()));
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
index 77a2084f79..ee950a032e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SelectList.java
@@ -37,6 +37,7 @@ public class SelectList {
 
     private boolean isDistinct;
     private Map<String, String> optHints;
+    private List<OrderByElement> orderByElements;
 
     // ///////////////////////////////////////
     // BEGIN: Members that need to be reset()
@@ -90,6 +91,12 @@ public class SelectList {
         }
     }
 
+    public void setOrderByElements(List<OrderByElement> orderByElements) {
+        if (orderByElements != null) {
+            this.orderByElements = orderByElements;
+        }
+    }
+
     public void reset() {
         for (SelectListItem item : items) {
             if (!item.isStar()) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
index 72957c0eff..a58097e9eb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java
@@ -55,6 +55,8 @@ public class AggregateFunction extends Function {
     public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
             ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx");
 
+    public static ImmutableSet<String> SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("group_concat");
+
     // Set if different from retType_, null otherwise.
     private Type intermediateType;
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
index c8561b54dc..4cdd1aceee 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AggregationNode.java
@@ -36,6 +36,7 @@ import org.apache.doris.thrift.TExplainLevel;
 import org.apache.doris.thrift.TExpr;
 import org.apache.doris.thrift.TPlanNode;
 import org.apache.doris.thrift.TPlanNodeType;
+import org.apache.doris.thrift.TSortInfo;
 
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Preconditions;
@@ -249,14 +250,27 @@ public class AggregationNode extends PlanNode {
     protected void toThrift(TPlanNode msg) {
         msg.node_type = TPlanNodeType.AGGREGATION_NODE;
         List<TExpr> aggregateFunctions = Lists.newArrayList();
+        List<TSortInfo> aggSortInfos = Lists.newArrayList();
         // only serialize agg exprs that are being materialized
         for (FunctionCallExpr e : aggInfo.getMaterializedAggregateExprs()) {
             aggregateFunctions.add(e.treeToThrift());
+            List<TExpr> orderingExpr = Lists.newArrayList();
+            List<Boolean> isAscs = Lists.newArrayList();
+            List<Boolean> nullFirsts = Lists.newArrayList();
+
+            e.getOrderByElements().forEach(o -> {
+                orderingExpr.add(o.getExpr().treeToThrift());
+                isAscs.add(o.getIsAsc());
+                nullFirsts.add(o.getNullsFirstParam());
+            });
+            aggSortInfos.add(new TSortInfo(orderingExpr, isAscs, nullFirsts));
         }
+
         msg.agg_node = new TAggregationNode(
                 aggregateFunctions,
                 aggInfo.getIntermediateTupleId().asInt(),
                 aggInfo.getOutputTupleId().asInt(), needsFinalize);
+        msg.agg_node.setAggSortInfos(aggSortInfos);
         msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
         List<Expr> groupingExprs = aggInfo.getGroupingExprs();
         if (groupingExprs != null) {
diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift
index 3024bac292..b3823f1cf6 100644
--- a/gensrc/thrift/PlanNodes.thrift
+++ b/gensrc/thrift/PlanNodes.thrift
@@ -543,6 +543,7 @@ struct TAggregationNode {
   // rows have been aggregated, and this node is not an intermediate node.
   5: required bool need_finalize
   6: optional bool use_streaming_preaggregation
+  7: optional list<TSortInfo> agg_sort_infos
 }
 
 struct TRepeatNode {
diff --git a/regression-test/data/query/group_concat/test_group_concat.out b/regression-test/data/query/group_concat/test_group_concat.out
index 94f73cc536..e61bd4fd4f 100644
--- a/regression-test/data/query/group_concat/test_group_concat.out
+++ b/regression-test/data/query/group_concat/test_group_concat.out
@@ -5,3 +5,23 @@ false, false
 -- !select --
 false
 
+-- !select --
+\N	\N
+103	255
+1001	1986, 1989
+1002	1989, 32767
+3021	1991, 1992, 32767
+5014	1985, 1991
+25699	1989
+2147483647	255, 1991, 32767, 32767
+
+-- !select --
+\N	\N
+103	255
+1001	1986:1989
+1002	1989:32767
+3021	1991:1992:32767
+5014	1985:1991
+25699	1989
+2147483647	255:1991:32767:32767
+
diff --git a/regression-test/suites/query/group_concat/test_group_concat.groovy b/regression-test/suites/query/group_concat/test_group_concat.groovy
index 6bb57dea44..12d420cbe0 100644
--- a/regression-test/suites/query/group_concat/test_group_concat.groovy
+++ b/regression-test/suites/query/group_concat/test_group_concat.groovy
@@ -23,4 +23,12 @@ suite("test_group_concat", "query") {
     qt_select """
                 SELECT group_concat(DISTINCT k6) FROM test_query_db.test where k6='false'
               """
+
+    qt_select """
+                SELECT abs(k3), group_concat(cast(abs(k2) as varchar) order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
+              """
+              
+    qt_select """
+                SELECT abs(k3), group_concat(cast(abs(k2) as varchar), ":" order by abs(k2), k1) FROM test_query_db.baseall group by abs(k3) order by abs(k3)
+              """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org