You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/04/20 06:46:24 UTC

[incubator-doris] branch master updated: [feature](vectorized)(function) Support min_by/max_by function. (#8623)

This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b4cd76847 [feature](vectorized)(function) Support min_by/max_by function. (#8623)
1b4cd76847 is described below

commit 1b4cd7684765342df3d1e5aa8c4675c68c6c097e
Author: zhannngchen <48...@users.noreply.github.com>
AuthorDate: Wed Apr 20 14:46:19 2022 +0800

    [feature](vectorized)(function) Support min_by/max_by function. (#8623)
    
    Support min_by/max_by on vectorized engine.
---
 be/src/vec/CMakeLists.txt                          |   1 +
 .../aggregate_function_min_max_by.cpp              | 116 +++++++++++++++
 .../aggregate_function_min_max_by.h                | 161 +++++++++++++++++++++
 .../aggregate_function_simple_factory.cpp          |   2 +
 be/test/CMakeLists.txt                             |   1 +
 .../aggregate_functions/agg_min_max_by_test.cpp    | 102 +++++++++++++
 .../sql-functions/aggregate-functions/max_by.md    |  56 +++++++
 .../sql-functions/aggregate-functions/min_by.md    |  56 +++++++
 .../sql-functions/aggregate-functions/max_by.md    |  56 +++++++
 .../sql-functions/aggregate-functions/min_by.md    |  56 +++++++
 .../org/apache/doris/analysis/AggregateInfo.java   |  13 +-
 .../java/org/apache/doris/catalog/FunctionSet.java |  14 ++
 12 files changed, 631 insertions(+), 3 deletions(-)

diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 3e789b7365..63f07d1408 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -26,6 +26,7 @@ set(VEC_FILES
   aggregate_functions/aggregate_function_distinct.cpp
   aggregate_functions/aggregate_function_sum.cpp
   aggregate_functions/aggregate_function_min_max.cpp
+  aggregate_functions/aggregate_function_min_max_by.cpp
   aggregate_functions/aggregate_function_null.cpp
   aggregate_functions/aggregate_function_uniq.cpp
   aggregate_functions/aggregate_function_hll_union_agg.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
new file mode 100644
index 0000000000..765e1f05ee
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_min_max.h"
+#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+
+/// Second dispatch stage for min_by / max_by. The storage type VT of the returned value has
+/// already been chosen by the caller; this picks the key storage from argument_types[1] and
+/// instantiates the aggregate function. Returns nullptr when the key type is not supported.
+template <template <typename, bool> class AggregateFunctionTemplate,
+          template <typename, typename> class Data, typename VT>
+static IAggregateFunction* create_aggregate_function_min_max_by_impl(
+        const DataTypes& argument_types) {
+    const DataTypePtr& value_type = argument_types[0];
+    const DataTypePtr& key_type = argument_types[1];
+
+    WhichDataType key_which(key_type);
+    switch (key_which.idx) {
+#define DISPATCH(TYPE)                                                                     \
+    case TypeIndex::TYPE:                                                                  \
+        return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<TYPE>>, false>( \
+                value_type, key_type);
+        FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+    case TypeIndex::String:
+        return new AggregateFunctionTemplate<Data<VT, SingleValueDataString>, false>(value_type,
+                                                                                     key_type);
+    case TypeIndex::Date:
+    case TypeIndex::DateTime:
+        // Date / DateTime keys are held and compared through their Int64 representation.
+        return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<Int64>>, false>(
+                value_type, key_type);
+    case TypeIndex::Decimal128:
+        return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<DecimalV2Value>>,
+                                             false>(value_type, key_type);
+    default:
+        break;
+    }
+    return nullptr;
+}
+
+/// Entry point shared by min_by and max_by. Checks the call shape with
+/// assert_no_parameters / assert_binary, picks the storage type for the returned value
+/// (argument_types[0]) and defers the key storage choice to
+/// create_aggregate_function_min_max_by_impl. Returns nullptr for an unsupported value type.
+template <template <typename, bool> class AggregateFunctionTemplate,
+          template <typename, typename> class Data>
+static IAggregateFunction* create_aggregate_function_min_max_by(const String& name,
+                                                                const DataTypes& argument_types,
+                                                                const Array& parameters) {
+    assert_no_parameters(name, parameters);
+    assert_binary(name, argument_types);
+
+    WhichDataType value_which(argument_types[0]);
+    switch (value_which.idx) {
+#define DISPATCH(TYPE)                                                                    \
+    case TypeIndex::TYPE:                                                                 \
+        return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \
+                                                         SingleValueDataFixed<TYPE>>(     \
+                argument_types);
+        FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+    case TypeIndex::String:
+        return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+                                                         SingleValueDataString>(argument_types);
+    case TypeIndex::Date:
+    case TypeIndex::DateTime:
+        // Date / DateTime values are held through their Int64 representation.
+        return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+                                                         SingleValueDataFixed<Int64>>(
+                argument_types);
+    case TypeIndex::Decimal128:
+        return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+                                                         SingleValueDataFixed<DecimalV2Value>>(
+                argument_types);
+    default:
+        break;
+    }
+    return nullptr;
+}
+
+namespace {
+
+/// Builds min_by / max_by for the given Data policy and takes ownership of the raw pointer
+/// produced by create_aggregate_function_min_max_by (null for unsupported argument types).
+template <template <typename, typename> class Data>
+AggregateFunctionPtr make_min_max_by_function(const std::string& name,
+                                              const DataTypes& argument_types,
+                                              const Array& parameters) {
+    IAggregateFunction* function =
+            create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy, Data>(
+                    name, argument_types, parameters);
+    return AggregateFunctionPtr(function);
+}
+
+} // namespace
+
+/// max_by(value, key): the value from the row whose key wins the "greater" comparison.
+/// result_is_nullable is not consulted here.
+AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name,
+                                                      const DataTypes& argument_types,
+                                                      const Array& parameters,
+                                                      const bool result_is_nullable) {
+    return make_min_max_by_function<AggregateFunctionMaxByData>(name, argument_types,
+                                                                parameters);
+}
+
+/// min_by(value, key): the value from the row whose key wins the "less" comparison.
+/// result_is_nullable is not consulted here.
+AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name,
+                                                      const DataTypes& argument_types,
+                                                      const Array& parameters,
+                                                      const bool result_is_nullable) {
+    return make_min_max_by_function<AggregateFunctionMinByData>(name, argument_types,
+                                                                parameters);
+}
+
+/// Registers both functions with the factory under their SQL names.
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory) {
+    factory.register_function("max_by", create_aggregate_function_max_by);
+    factory.register_function("min_by", create_aggregate_function_min_by);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
new file mode 100644
index 0000000000..5fb061c3e9
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "common/logging.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+/// State shared by min_by / max_by: the candidate value (VT) and the key (KT) that selected it.
+template <typename VT, typename KT>
+struct AggregateFunctionMinMaxByBaseData {
+public:
+    /// Appends the stored value (never the key) to the result column.
+    void insert_result_into(IColumn& to) const { value.insert_result_into(to); }
+
+    /// Forgets both the value and the key.
+    void reset() {
+        value.reset();
+        key.reset();
+    }
+
+    /// Serialized layout is value first, then key; read() consumes the same order.
+    void write(BufferWritable& buf) const {
+        value.write(buf);
+        key.write(buf);
+    }
+
+    void read(BufferReadable& buf) {
+        value.read(buf);
+        key.read(buf);
+    }
+
+protected:
+    VT value; // payload produced as the aggregate result
+    KT key;   // ordering key compared by the derived min/max data policies
+};
+
+/// max_by policy: keeps the value that belongs to the greatest key seen so far.
+template <typename VT, typename KT>
+struct AggregateFunctionMaxByData : public AggregateFunctionMinMaxByBaseData<VT, KT> {
+    using Self = AggregateFunctionMaxByData;
+
+    /// Offers row `row_num`. When KT::change_if_greater reports that the key was replaced,
+    /// the value is taken from the same row. Returns whether the state changed.
+    bool change_if_better(const IColumn& value_column, const IColumn& key_column, size_t row_num,
+                          Arena* arena) {
+        const bool replaced = this->key.change_if_greater(key_column, row_num, arena);
+        if (replaced) {
+            this->value.change(value_column, row_num, arena);
+        }
+        return replaced;
+    }
+
+    /// Merge step: adopts rhs's (value, key) pair when rhs's key is greater.
+    bool change_if_better(const Self& rhs, Arena* arena) {
+        const bool replaced = this->key.change_if_greater(rhs.key, arena);
+        if (replaced) {
+            this->value.change(rhs.value, arena);
+        }
+        return replaced;
+    }
+
+    static const char* name() { return "max_by"; }
+};
+
+/// min_by policy: keeps the value that belongs to the smallest key seen so far.
+template <typename VT, typename KT>
+struct AggregateFunctionMinByData : public AggregateFunctionMinMaxByBaseData<VT, KT> {
+    using Self = AggregateFunctionMinByData;
+
+    /// Offers row `row_num`. When KT::change_if_less reports that the key was replaced,
+    /// the value is taken from the same row. Returns whether the state changed.
+    bool change_if_better(const IColumn& value_column, const IColumn& key_column, size_t row_num,
+                          Arena* arena) {
+        const bool replaced = this->key.change_if_less(key_column, row_num, arena);
+        if (replaced) {
+            this->value.change(value_column, row_num, arena);
+        }
+        return replaced;
+    }
+
+    /// Merge step: adopts rhs's (value, key) pair when rhs's key is smaller.
+    bool change_if_better(const Self& rhs, Arena* arena) {
+        const bool replaced = this->key.change_if_less(rhs.key, arena);
+        if (replaced) {
+            this->value.change(rhs.value, arena);
+        }
+        return replaced;
+    }
+
+    static const char* name() { return "min_by"; }
+};
+
+/// Aggregate function shared by min_by(value, key) and max_by(value, key). The Data policy
+/// (AggregateFunctionMinByData or AggregateFunctionMaxByData) decides which key wins; the
+/// result column has the type of the first (value) argument.
+template <typename Data, bool AllocatesMemoryInArena>
+class AggregateFunctionsMinMaxBy final
+        : public IAggregateFunctionDataHelper<
+                  Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>> {
+private:
+    // Held by value rather than as references into this->argument_types: reference members
+    // would keep pointing into the source object's vector if this object were ever copied.
+    DataTypePtr value_type;
+    DataTypePtr key_type;
+
+public:
+    AggregateFunctionsMinMaxBy(const DataTypePtr& value_type_, const DataTypePtr& key_type_)
+            : IAggregateFunctionDataHelper<
+                      Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>>(
+                      {value_type_, key_type_}, {}),
+              value_type(value_type_),
+              key_type(key_type_) {
+        // Both min_by and max_by order rows by the key, so the key type must be comparable.
+        CHECK(key_type->is_comparable());
+    }
+
+    String get_name() const override { return Data::name(); }
+
+    /// The result has the same type as the value argument.
+    DataTypePtr get_return_type() const override { return value_type; }
+
+    /// columns[0] is the value column, columns[1] is the key column.
+    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+             Arena* arena) const override {
+        this->data(place).change_if_better(*columns[0], *columns[1], row_num, arena);
+    }
+
+    void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+    /// Adopts rhs's (value, key) pair when its key beats the one stored in place.
+    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+               Arena* arena) const override {
+        this->data(place).change_if_better(this->data(rhs), arena);
+    }
+
+    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+        this->data(place).write(buf);
+    }
+
+    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+                     Arena*) const override {
+        this->data(place).read(buf);
+    }
+
+    bool allocates_memory_in_arena() const override { return AllocatesMemoryInArena; }
+
+    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+        this->data(place).insert_result_into(to);
+    }
+};
+
+/// Factory callback registered as "max_by": builds max_by(value, key), which yields the
+/// value from the row with the greatest key. The returned pointer is empty when the value
+/// or key type is unsupported; result_is_nullable is currently not consulted.
+AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name,
+                                                      const DataTypes& argument_types,
+                                                      const Array& parameters,
+                                                      const bool result_is_nullable);
+
+/// Factory callback registered as "min_by": builds min_by(value, key), which yields the
+/// value from the row with the smallest key. The returned pointer is empty when the value
+/// or key type is unsupported; result_is_nullable is currently not consulted.
+AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name,
+                                                      const DataTypes& argument_types,
+                                                      const Array& parameters,
+                                                      const bool result_is_nullable);
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index fcf333c1bd..6315fd6600 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -29,6 +29,7 @@ class AggregateFunctionSimpleFactory;
 void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_count(AggregateFunctionSimpleFactory& factory);
 void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& factory);
@@ -51,6 +52,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
     std::call_once(oc, [&]() {
         register_aggregate_function_sum(instance);
         register_aggregate_function_minmax(instance);
+        register_aggregate_function_min_max_by(instance);
         register_aggregate_function_avg(instance);
         register_aggregate_function_count(instance);
         register_aggregate_function_uniq(instance);
diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt
index 4c784c1ff9..a1b0cfddbf 100644
--- a/be/test/CMakeLists.txt
+++ b/be/test/CMakeLists.txt
@@ -331,6 +331,7 @@ set(VEC_TEST_FILES
     vec/aggregate_functions/agg_test.cpp
     vec/aggregate_functions/agg_min_max_test.cpp
     vec/aggregate_functions/vec_window_funnel_test.cpp
+    vec/aggregate_functions/agg_min_max_by_test.cpp
     vec/core/block_test.cpp
     vec/core/column_array_test.cpp
     vec/core/column_complex_test.cpp
diff --git a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
new file mode 100644
index 0000000000..a25af83b18
--- /dev/null
+++ b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
+#include "vec/columns/column_vector.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+
+// Number of rows fed to each aggregate function under test.
+constexpr int agg_test_batch_size = 4096;
+
+namespace doris::vectorized {
+// declare function
+// Defined in aggregate_function_min_max_by.cpp; that file's header does not declare it.
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory);
+
+// Parameterized over the SQL function name: "min_by" or "max_by".
+struct AggMinMaxByTest : ::testing::TestWithParam<std::string> {};
+
+TEST_P(AggMinMaxByTest, min_max_by_test) {
+    const std::string min_max_by_type = GetParam();
+    // Prepare test data. The value column holds 0..N-1. The int32 key decreases (N..1), so
+    // max_by must return the first value and min_by the last. The string key "foo_<i>" is
+    // ordered lexicographically, so the expected rows are tracked while generating it.
+    auto column_vector_value = ColumnInt32::create();
+    auto column_vector_key_int32 = ColumnInt32::create();
+    auto column_vector_key_str = ColumnString::create();
+    std::pair<std::string, int32_t> max_pair {"foo_0", 0};
+    auto min_pair = max_pair;
+    for (int i = 0; i < agg_test_batch_size; i++) {
+        column_vector_value->insert(cast_to_nearest_field_type(i));
+        column_vector_key_int32->insert(cast_to_nearest_field_type(agg_test_batch_size - i));
+        std::string str_val = fmt::format("foo_{}", i);
+        if (max_pair.first < str_val) {
+            max_pair.first = str_val;
+            max_pair.second = i;
+        }
+        if (min_pair.first > str_val) {
+            min_pair.first = str_val;
+            min_pair.second = i;
+        }
+        column_vector_key_str->insert(cast_to_nearest_field_type(str_val));
+    }
+
+    // Prepare test function and parameters.
+    AggregateFunctionSimpleFactory factory;
+    register_aggregate_function_min_max_by(factory);
+
+    // Test on 2 kinds of key types (int32, string).
+    for (int i = 0; i < 2; i++) {
+        const bool int_key = (i == 0);
+        DataTypes data_types = {std::make_shared<DataTypeInt32>(),
+                                int_key ? DataTypePtr(std::make_shared<DataTypeInt32>())
+                                        : DataTypePtr(std::make_shared<DataTypeString>())};
+        Array array;
+        auto agg_function = factory.get(min_max_by_type, data_types, array);
+        // Fail fast instead of dereferencing a null function if the types are rejected.
+        ASSERT_TRUE(agg_function != nullptr);
+        std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
+        AggregateDataPtr place = memory.get();
+        agg_function->create(place);
+
+        // Do aggregation.
+        const IColumn* columns[2] = {
+                column_vector_value.get(),
+                int_key ? static_cast<const IColumn*>(column_vector_key_int32.get())
+                        : static_cast<const IColumn*>(column_vector_key_str.get())};
+        for (int j = 0; j < agg_test_batch_size; j++) {
+            agg_function->add(place, columns, j, nullptr);
+        }
+
+        // Check result. Non-fatal EXPECTs keep destroy() below reachable on failure; a
+        // failing ASSERT would return early and skip releasing the aggregate state.
+        ColumnInt32 ans;
+        agg_function->insert_result_into(place, ans);
+        const bool is_max = (min_max_by_type == "max_by");
+        const int32_t expected = int_key ? (is_max ? 0 : agg_test_batch_size - 1)
+                                         : (is_max ? max_pair.second : min_pair.second);
+        EXPECT_EQ(1U, ans.size());
+        if (ans.size() == 1) {
+            EXPECT_EQ(expected, ans.get_element(0));
+        }
+        agg_function->destroy(place);
+    }
+}
+
+// Run the test body once per function name.
+INSTANTIATE_TEST_SUITE_P(Params, AggMinMaxByTest, ::testing::Values("min_by", "max_by"));
+} // namespace doris::vectorized
diff --git a/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md b/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md
new file mode 100644
index 0000000000..27819f26d4
--- /dev/null
+++ b/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md
@@ -0,0 +1,56 @@
+---
+{
+    "title": "MAX_BY",
+    "language": "en"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MAX_BY
+## description
+### Syntax
+
+`MAX_BY(expr1, expr2)`
+
+
+Returns the value of expr1 from the row that has the maximum value of expr2 in the group. If several rows share that maximum, which of their expr1 values is returned is not specified.
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1   | k2   | k3   | k4   |
++------+------+------+------+
+|    0 | 3    | 2    |  100 |
+|    1 | 2    | 3    |    4 |
+|    4 | 3    | 2    |    1 |
+|    3 | 4    | 2    |    1 |
++------+------+------+------+
+
+MySQL > select max_by(k1, k4) from tbl;
++--------------------+
+| max_by(`k1`, `k4`) |
++--------------------+
+|                  0 |
++--------------------+ 
+```
+## keyword
+MAX_BY
diff --git a/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md b/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md
new file mode 100644
index 0000000000..98a3478cb4
--- /dev/null
+++ b/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md
@@ -0,0 +1,56 @@
+---
+{
+    "title": "MIN_BY",
+    "language": "en"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MIN_BY
+## description
+### Syntax
+
+`MIN_BY(expr1, expr2)`
+
+
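+Returns the value of expr1 from the row that has the minimum value of expr2 in the group. If several rows share that minimum (as k4 = 1 does in the example below), which of their expr1 values is returned is not specified.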
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1   | k2   | k3   | k4   |
++------+------+------+------+
+|    0 | 3    | 2    |  100 |
+|    1 | 2    | 3    |    4 |
+|    4 | 3    | 2    |    1 |
+|    3 | 4    | 2    |    1 |
++------+------+------+------+
+
+MySQL > select min_by(k1, k4) from tbl;
++--------------------+
+| min_by(`k1`, `k4`) |
++--------------------+
+|                  4 |
++--------------------+ 
+```
+## keyword
+MIN_BY
diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md
new file mode 100644
index 0000000000..ce9d71da38
--- /dev/null
+++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md
@@ -0,0 +1,56 @@
+---
+{
+    "title": "MAX_BY",
+    "language": "zh-CN"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MAX_BY
+## description
+### Syntax
+
+`MAX_BY(expr1, expr2)`
+
+
+返回 expr2 取最大值的那一行对应的 expr1 的值。若有多行的 expr2 同为最大值，返回其中哪一行的 expr1 是不确定的。
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1   | k2   | k3   | k4   |
++------+------+------+------+
+|    0 | 3    | 2    |  100 |
+|    1 | 2    | 3    |    4 |
+|    4 | 3    | 2    |    1 |
+|    3 | 4    | 2    |    1 |
++------+------+------+------+
+
+MySQL > select max_by(k1, k4) from tbl;
++--------------------+
+| max_by(`k1`, `k4`) |
++--------------------+
+|                  0 |
++--------------------+ 
+```
+## keyword
+MAX_BY
diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md
new file mode 100644
index 0000000000..59a20cf7f9
--- /dev/null
+++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md
@@ -0,0 +1,56 @@
+---
+{
+    "title": "MIN_BY",
+    "language": "zh-CN"
+}
+---
+
+<!-- 
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MIN_BY
+## description
+### Syntax
+
+`MIN_BY(expr1, expr2)`
+
+
+返回 expr2 取最小值的那一行对应的 expr1 的值。若有多行的 expr2 同为最小值（如下例中 k4 = 1 的两行），返回其中哪一行的 expr1 是不确定的。
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1   | k2   | k3   | k4   |
++------+------+------+------+
+|    0 | 3    | 2    |  100 |
+|    1 | 2    | 3    |    4 |
+|    4 | 3    | 2    |    1 |
+|    3 | 4    | 2    |    1 |
++------+------+------+------+
+
+MySQL > select min_by(k1, k4) from tbl;
++--------------------+
+| min_by(`k1`, `k4`) |
++--------------------+
+|                  4 |
++--------------------+ 
+```
+## keyword
+MIN_BY
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index 10272b126e..1f8f67da66 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -455,10 +455,17 @@ public final class AggregateInfo extends AggregateInfoBase {
         for (int i = 0; i < getAggregateExprs().size(); ++i) {
             FunctionCallExpr inputExpr = getAggregateExprs().get(i);
             Preconditions.checkState(inputExpr.isAggregateFunction());
-            Expr aggExprParam =
-                    new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()));
+            List<Expr> paramExprs = new ArrayList<>();
+            // TODO(zhannngchen), change intermediate argument to a list, and remove this
+            // ad-hoc logic
+            if (inputExpr.fn.functionName().equals("max_by") ||
+                    inputExpr.fn.functionName().equals("min_by")) {
+                paramExprs.addAll(inputExpr.getFnParams().exprs());
+            } else {
+                paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())));
+            }
             FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
-                    inputExpr, Lists.newArrayList(aggExprParam));
+                    inputExpr, paramExprs);
             aggExpr.analyzeNoThrow(analyzer);
             aggExprs.add(aggExpr);
         }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index ac8403843e..e95b1c66bb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1598,6 +1598,20 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
                     minMaxSerializeOrFinalize, minMaxGetValue,
                     null, minMaxSerializeOrFinalize, true, true, false, true));
 
+            // vectorized
+            for (Type kt : Type.getSupportedTypes()) {
+                if (kt.isNull()) {
+                    continue;
+                }
+                addBuiltin(AggregateFunction.createBuiltin("max_by", Lists.newArrayList(t, kt), t, Type.VARCHAR,
+                        "", "", "", "", "", null, "",
+                        true, true, false, true));
+                addBuiltin(AggregateFunction.createBuiltin("min_by", Lists.newArrayList(t, kt), t, Type.VARCHAR,
+                        "", "", "", "", "", null, "",
+                        true, true, false, true));
+            }
+
+
             // NDV
             // ndv return string
             addBuiltin(AggregateFunction.createBuiltin("ndv", Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org