You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by mo...@apache.org on 2022/04/20 07:08:17 UTC
[incubator-doris] 01/03: [feature](vectorized)(function) Support min_by/max_by function. (#8623)
This is an automated email from the ASF dual-hosted git repository.
morningman pushed a commit to branch dev-1.0.1
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
commit 5ce3ae53e743f188535a99dfa98236ab0b7a5e5f
Author: zhannngchen <48...@users.noreply.github.com>
AuthorDate: Wed Apr 20 14:46:19 2022 +0800
[feature](vectorized)(function) Support min_by/max_by function. (#8623)
Support min_by/max_by on vectorized engine.
---
be/src/vec/CMakeLists.txt | 1 +
.../aggregate_function_min_max_by.cpp | 116 +++++++++++++++
.../aggregate_function_min_max_by.h | 161 +++++++++++++++++++++
.../aggregate_function_simple_factory.cpp | 2 +
.../aggregate_functions/agg_min_max_by_test.cpp | 102 +++++++++++++
.../sql-functions/aggregate-functions/max_by.md | 56 +++++++
.../sql-functions/aggregate-functions/min_by.md | 56 +++++++
.../sql-functions/aggregate-functions/max_by.md | 56 +++++++
.../sql-functions/aggregate-functions/min_by.md | 56 +++++++
.../org/apache/doris/analysis/AggregateInfo.java | 13 +-
.../java/org/apache/doris/catalog/FunctionSet.java | 14 ++
11 files changed, 630 insertions(+), 3 deletions(-)
diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt
index 243fd1869e..2d30b33f50 100644
--- a/be/src/vec/CMakeLists.txt
+++ b/be/src/vec/CMakeLists.txt
@@ -26,6 +26,7 @@ set(VEC_FILES
aggregate_functions/aggregate_function_distinct.cpp
aggregate_functions/aggregate_function_sum.cpp
aggregate_functions/aggregate_function_min_max.cpp
+ aggregate_functions/aggregate_function_min_max_by.cpp
aggregate_functions/aggregate_function_null.cpp
aggregate_functions/aggregate_function_uniq.cpp
aggregate_functions/aggregate_function_hll_union_agg.cpp
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
new file mode 100644
index 0000000000..765e1f05ee
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.cpp
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "vec/aggregate_functions/aggregate_function_min_max.h"
+#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
+
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/factory_helpers.h"
+#include "vec/aggregate_functions/helpers.h"
+
+namespace doris::vectorized {
+
+/// min_by, max_by: with the value state VT already chosen, selects the key state from argument_types[1].
+template <template <typename, bool> class AggregateFunctionTemplate,
+ template <typename, typename> class Data, typename VT>
+static IAggregateFunction* create_aggregate_function_min_max_by_impl(
+ const DataTypes& argument_types) {
+ const DataTypePtr& value_arg_type = argument_types[0]; // argument 0: type of the value argument
+ const DataTypePtr& key_arg_type = argument_types[1]; // argument 1: type of the key argument
+
+ WhichDataType which(key_arg_type); // every branch below tests the key type only
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<TYPE>>, false>( \
+ value_arg_type, key_arg_type);
+ FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+ if (which.idx == TypeIndex::String) {
+ return new AggregateFunctionTemplate<Data<VT, SingleValueDataString>, false>(value_arg_type,
+ key_arg_type);
+ }
+ if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) { // both date key types use the Int64 state
+ return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<Int64>>, false>(
+ value_arg_type, key_arg_type);
+ }
+ if (which.idx == TypeIndex::Decimal128) {
+ return new AggregateFunctionTemplate<Data<VT, SingleValueDataFixed<DecimalV2Value>>, false>(
+ value_arg_type, key_arg_type);
+ }
+ return nullptr; // key type matched none of the branches above
+}
+
+/// min_by, max_by: selects the value state from argument_types[0], then delegates key-state selection to the _impl helper.
+template <template <typename, bool> class AggregateFunctionTemplate,
+ template <typename, typename> class Data>
+static IAggregateFunction* create_aggregate_function_min_max_by(const String& name,
+ const DataTypes& argument_types,
+ const Array& parameters) {
+ assert_no_parameters(name, parameters);
+ assert_binary(name, argument_types); // NOTE(review): presumably enforces exactly two arguments - confirm in the helper
+
+ const DataTypePtr& value_arg_type = argument_types[0]; // argument 0: type of the value argument
+
+ WhichDataType which(value_arg_type); // every branch below tests the value type only
+#define DISPATCH(TYPE) \
+ if (which.idx == TypeIndex::TYPE) \
+ return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data, \
+ SingleValueDataFixed<TYPE>>( \
+ argument_types);
+ FOR_NUMERIC_TYPES(DISPATCH)
+#undef DISPATCH
+ if (which.idx == TypeIndex::String) {
+ return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+ SingleValueDataString>(argument_types);
+ }
+ if (which.idx == TypeIndex::DateTime || which.idx == TypeIndex::Date) { // both date value types use the Int64 state
+ return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+ SingleValueDataFixed<Int64>>(
+ argument_types);
+ }
+ if (which.idx == TypeIndex::Decimal128) {
+ return create_aggregate_function_min_max_by_impl<AggregateFunctionTemplate, Data,
+ SingleValueDataFixed<DecimalV2Value>>(
+ argument_types);
+ }
+ return nullptr; // value type matched none of the branches above
+}
+
+AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name, // creator registered under "max_by"
+ const DataTypes& argument_types,
+ const Array& parameters,
+ const bool result_is_nullable) { // result_is_nullable is not read in this body
+ return AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+ AggregateFunctionMaxByData>(
+ name, argument_types, parameters));
+}
+
+AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name, // creator registered under "min_by"
+ const DataTypes& argument_types,
+ const Array& parameters,
+ const bool result_is_nullable) { // result_is_nullable is not read in this body
+ return AggregateFunctionPtr(create_aggregate_function_min_max_by<AggregateFunctionsMinMaxBy,
+ AggregateFunctionMinByData>(
+ name, argument_types, parameters));
+}
+
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory) { // registers both creators on `factory`
+ factory.register_function("max_by", create_aggregate_function_max_by);
+ factory.register_function("min_by", create_aggregate_function_min_by);
+}
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
new file mode 100644
index 0000000000..5fb061c3e9
--- /dev/null
+++ b/be/src/vec/aggregate_functions/aggregate_function_min_max_by.h
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "common/logging.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/columns/column_decimal.h"
+#include "vec/columns/column_vector.h"
+#include "vec/common/assert_cast.h"
+#include "vec/io/io_helper.h"
+
+namespace doris::vectorized {
+template <typename VT, typename KT> // VT: state type holding the value, KT: state type holding the key
+struct AggregateFunctionMinMaxByBaseData {
+protected:
+ VT value;
+ KT key;
+
+public:
+ void insert_result_into(IColumn& to) const { value.insert_result_into(to); } // only the value is emitted, never the key
+
+ void reset() {
+ value.reset();
+ key.reset();
+ }
+ void write(BufferWritable& buf) const { // value is written before key
+ value.write(buf);
+ key.write(buf);
+ }
+
+ void read(BufferReadable& buf) { // same order as write(): value, then key
+ value.read(buf);
+ key.read(buf);
+ }
+};
+
+template <typename VT, typename KT>
+struct AggregateFunctionMaxByData : public AggregateFunctionMinMaxByBaseData<VT, KT> { // state used for max_by
+ using Self = AggregateFunctionMaxByData;
+ bool change_if_better(const IColumn& value_column, const IColumn& key_column, size_t row_num, // update from one row
+ Arena* arena) {
+ if (this->key.change_if_greater(key_column, row_num, arena)) { // value is replaced only when the key was replaced
+ this->value.change(value_column, row_num, arena);
+ return true;
+ }
+ return false; // key not replaced, value left as it was
+ }
+
+ bool change_if_better(const Self& to, Arena* arena) { // update from another state of the same type
+ if (this->key.change_if_greater(to.key, arena)) {
+ this->value.change(to.value, arena);
+ return true;
+ }
+ return false;
+ }
+
+ static const char* name() { return "max_by"; }
+};
+
+template <typename VT, typename KT>
+struct AggregateFunctionMinByData : public AggregateFunctionMinMaxByBaseData<VT, KT> { // state used for min_by
+ using Self = AggregateFunctionMinByData;
+ bool change_if_better(const IColumn& value_column, const IColumn& key_column, size_t row_num, // update from one row
+ Arena* arena) {
+ if (this->key.change_if_less(key_column, row_num, arena)) { // value is replaced only when the key was replaced
+ this->value.change(value_column, row_num, arena);
+ return true;
+ }
+ return false; // key not replaced, value left as it was
+ }
+
+ bool change_if_better(const Self& to, Arena* arena) { // update from another state of the same type
+ if (this->key.change_if_less(to.key, arena)) {
+ this->value.change(to.value, arena);
+ return true;
+ }
+ return false;
+ }
+
+ static const char* name() { return "min_by"; }
+};
+
+template <typename Data, bool AllocatesMemoryInArena> // Data: per-group state; the bool is what allocates_memory_in_arena() returns
+class AggregateFunctionsMinMaxBy final
+ : public IAggregateFunctionDataHelper<
+ Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>> {
+private:
+ DataTypePtr& value_type; // reference bound to argument_types[0] in the constructor
+ DataTypePtr& key_type; // reference bound to argument_types[1] in the constructor
+
+public:
+ AggregateFunctionsMinMaxBy(const DataTypePtr& value_type_, const DataTypePtr& key_type_)
+ : IAggregateFunctionDataHelper<
+ Data, AggregateFunctionsMinMaxBy<Data, AllocatesMemoryInArena>>(
+ {value_type_, key_type_}, {}),
+ value_type(this->argument_types[0]),
+ key_type(this->argument_types[1]) {
+ if (StringRef(Data::name()) == StringRef("min_by") || // both data structs in this header return one of these names
+ StringRef(Data::name()) == StringRef("max_by")) {
+ CHECK(key_type_->is_comparable()); // the key type must be comparable
+ }
+ }
+
+ String get_name() const override { return Data::name(); }
+
+ DataTypePtr get_return_type() const override { return value_type; } // result has the value argument's type
+
+ void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
+ Arena* arena) const override {
+ this->data(place).change_if_better(*columns[0], *columns[1], row_num, arena); // columns[0] = value, columns[1] = key
+ }
+
+ void reset(AggregateDataPtr place) const override { this->data(place).reset(); }
+
+ void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
+ Arena* arena) const override {
+ this->data(place).change_if_better(this->data(rhs), arena); // keeps rhs's pair only if its key is better
+ }
+
+ void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
+ this->data(place).write(buf);
+ }
+
+ void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
+ Arena*) const override { // the arena argument is unnamed and unused here
+ this->data(place).read(buf);
+ }
+
+ bool allocates_memory_in_arena() const override { return AllocatesMemoryInArena; }
+
+ void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
+ this->data(place).insert_result_into(to);
+ }
+};
+
+AggregateFunctionPtr create_aggregate_function_max_by(const std::string& name, // defined in aggregate_function_min_max_by.cpp
+ const DataTypes& argument_types,
+ const Array& parameters,
+ const bool result_is_nullable);
+
+AggregateFunctionPtr create_aggregate_function_min_by(const std::string& name, // defined in aggregate_function_min_max_by.cpp
+ const DataTypes& argument_types,
+ const Array& parameters,
+ const bool result_is_nullable);
+
+} // namespace doris::vectorized
diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
index fcf333c1bd..6315fd6600 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
+++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp
@@ -29,6 +29,7 @@ class AggregateFunctionSimpleFactory;
void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_combinator_null(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory);
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_avg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_count(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_HLL_union_agg(AggregateFunctionSimpleFactory& factory);
@@ -51,6 +52,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
std::call_once(oc, [&]() {
register_aggregate_function_sum(instance);
register_aggregate_function_minmax(instance);
+ register_aggregate_function_min_max_by(instance);
register_aggregate_function_avg(instance);
register_aggregate_function_count(instance);
register_aggregate_function_uniq(instance);
diff --git a/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
new file mode 100644
index 0000000000..a25af83b18
--- /dev/null
+++ b/be/test/vec/aggregate_functions/agg_min_max_by_test.cpp
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "vec/aggregate_functions/aggregate_function.h"
+#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
+#include "vec/aggregate_functions/aggregate_function_min_max_by.h"
+#include "vec/columns/column_vector.h"
+#include "vec/data_types/data_type.h"
+#include "vec/data_types/data_type_number.h"
+#include "vec/data_types/data_type_string.h"
+
+const int agg_test_batch_size = 4096;
+
+namespace doris::vectorized {
+// declare function
+void register_aggregate_function_min_max_by(AggregateFunctionSimpleFactory& factory);
+
+class AggMinMaxByTest : public ::testing::TestWithParam<std::string> {}; // the parameter is the function name under test
+
+TEST_P(AggMinMaxByTest, min_max_by_test) {
+ std::string min_max_by_type = GetParam();
+ // Prepare test data.
+ auto column_vector_value = ColumnInt32::create();
+ auto column_vector_key_int32 = ColumnInt32::create();
+ auto column_vector_key_str = ColumnString::create();
+ auto max_pair = std::make_pair<std::string, int32_t>("foo_0", 0);
+ auto min_pair = max_pair;
+ for (int i = 0; i < agg_test_batch_size; i++) {
+ column_vector_value->insert(cast_to_nearest_field_type(i));
+ column_vector_key_int32->insert(cast_to_nearest_field_type(agg_test_batch_size - i)); // int keys fall as values rise
+ std::string str_val = fmt::format("foo_{}", i);
+ if (max_pair.first < str_val) { // track the largest string key (std::string ordering) and its value
+ max_pair.first = str_val;
+ max_pair.second = i;
+ }
+ if (min_pair.first > str_val) { // track the smallest string key (std::string ordering) and its value
+ min_pair.first = str_val;
+ min_pair.second = i;
+ }
+ column_vector_key_str->insert(cast_to_nearest_field_type(str_val));
+ }
+
+ // Prepare test function and parameters.
+ AggregateFunctionSimpleFactory factory;
+ register_aggregate_function_min_max_by(factory);
+
+ // Test on two kinds of key types (int32, string).
+ for (int i = 0; i < 2; i++) {
+ DataTypes data_types = {std::make_shared<DataTypeInt32>(),
+ i == 0 ? (DataTypePtr)std::make_shared<DataTypeInt32>()
+ : (DataTypePtr)std::make_shared<DataTypeString>()};
+ Array array;
+ auto agg_function = factory.get(min_max_by_type, data_types, array);
+ std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]);
+ AggregateDataPtr place = memory.get();
+ agg_function->create(place);
+
+ // Do aggregation.
+ const IColumn* columns[2] = {column_vector_value.get(),
+ i == 0 ? (IColumn*)column_vector_key_int32.get()
+ : (IColumn*)column_vector_key_str.get()};
+ for (int j = 0; j < agg_test_batch_size; j++) {
+ agg_function->add(place, columns, j, nullptr); // no arena is supplied
+ }
+
+ // Check result.
+ ColumnInt32 ans;
+ agg_function->insert_result_into(place, ans);
+ if (i == 0) {
+ // Key type is int32: the largest key sits at row 0 (value 0), the smallest at the last row.
+ ASSERT_EQ(min_max_by_type == "max_by" ? 0 : agg_test_batch_size - 1,
+ ans.get_element(0));
+ } else {
+ // Key type is string: expected values come from max_pair / min_pair computed above.
+ ASSERT_EQ(min_max_by_type == "max_by" ? max_pair.second : min_pair.second,
+ ans.get_element(0));
+ }
+ agg_function->destroy(place);
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(Params, AggMinMaxByTest,
+ ::testing::ValuesIn(std::vector<std::string> {"min_by", "max_by"}));
+} // namespace doris::vectorized
diff --git a/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md b/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md
new file mode 100644
index 0000000000..27819f26d4
--- /dev/null
+++ b/docs/en/sql-reference/sql-functions/aggregate-functions/max_by.md
@@ -0,0 +1,56 @@
+---
+{
+ "title": "MAX_BY",
+ "language": "en"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MAX_BY
+## description
+### Syntax
+
+`MAX_BY(expr1, expr2)`
+
+
+Returns the value of expr1 associated with the maximum value of expr2 in a group.
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1 | k2 | k3 | k4 |
++------+------+------+------+
+| 0 | 3 | 2 | 100 |
+| 1 | 2 | 3 | 4 |
+| 4 | 3 | 2 | 1 |
+| 3 | 4 | 2 | 1 |
++------+------+------+------+
+
+MySQL > select max_by(k1, k4) from tbl;
++--------------------+
+| max_by(`k1`, `k4`) |
++--------------------+
+| 0 |
++--------------------+
+```
+## keyword
+MAX_BY
diff --git a/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md b/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md
new file mode 100644
index 0000000000..98a3478cb4
--- /dev/null
+++ b/docs/en/sql-reference/sql-functions/aggregate-functions/min_by.md
@@ -0,0 +1,56 @@
+---
+{
+ "title": "MIN_BY",
+ "language": "en"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MIN_BY
+## description
+### Syntax
+
+`MIN_BY(expr1, expr2)`
+
+
+Returns the value of expr1 associated with the minimum value of expr2 in a group.
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1 | k2 | k3 | k4 |
++------+------+------+------+
+| 0 | 3 | 2 | 100 |
+| 1 | 2 | 3 | 4 |
+| 4 | 3 | 2 | 1 |
+| 3 | 4 | 2 | 1 |
++------+------+------+------+
+
+MySQL > select min_by(k1, k4) from tbl;
++--------------------+
+| min_by(`k1`, `k4`) |
++--------------------+
+| 4 |
++--------------------+
+```
+## keyword
+MIN_BY
diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md
new file mode 100644
index 0000000000..ce9d71da38
--- /dev/null
+++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/max_by.md
@@ -0,0 +1,56 @@
+---
+{
+ "title": "MAX_BY",
+ "language": "zh-CN"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MAX_BY
+## description
+### Syntax
+
+`MAX_BY(expr1, expr2)`
+
+
+返回与 expr2 的最大值关联的 expr1 的值。
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1 | k2 | k3 | k4 |
++------+------+------+------+
+| 0 | 3 | 2 | 100 |
+| 1 | 2 | 3 | 4 |
+| 4 | 3 | 2 | 1 |
+| 3 | 4 | 2 | 1 |
++------+------+------+------+
+
+MySQL > select max_by(k1, k4) from tbl;
++--------------------+
+| max_by(`k1`, `k4`) |
++--------------------+
+| 0 |
++--------------------+
+```
+## keyword
+MAX_BY
diff --git a/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md
new file mode 100644
index 0000000000..59a20cf7f9
--- /dev/null
+++ b/docs/zh-CN/sql-reference/sql-functions/aggregate-functions/min_by.md
@@ -0,0 +1,56 @@
+---
+{
+ "title": "MIN_BY",
+ "language": "zh-CN"
+}
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+# MIN_BY
+## description
+### Syntax
+
+`MIN_BY(expr1, expr2)`
+
+
+返回与 expr2 的最小值关联的 expr1 的值。
+
+## example
+```
+MySQL > select * from tbl;
++------+------+------+------+
+| k1 | k2 | k3 | k4 |
++------+------+------+------+
+| 0 | 3 | 2 | 100 |
+| 1 | 2 | 3 | 4 |
+| 4 | 3 | 2 | 1 |
+| 3 | 4 | 2 | 1 |
++------+------+------+------+
+
+MySQL > select min_by(k1, k4) from tbl;
++--------------------+
+| min_by(`k1`, `k4`) |
++--------------------+
+| 4 |
++--------------------+
+```
+## keyword
+MIN_BY
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
index 3e14d8964f..f02d15c7b9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/AggregateInfo.java
@@ -452,10 +452,17 @@ public final class AggregateInfo extends AggregateInfoBase {
for (int i = 0; i < getAggregateExprs().size(); ++i) {
FunctionCallExpr inputExpr = getAggregateExprs().get(i);
Preconditions.checkState(inputExpr.isAggregateFunction());
- Expr aggExprParam =
- new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size()));
+ List<Expr> paramExprs = new ArrayList<>();
+ // TODO(zhannngchen), change intermediate argument to a list, and remove this
+ // ad-hoc logic
+ if (inputExpr.fn.functionName().equals("max_by") ||
+ inputExpr.fn.functionName().equals("min_by")) {
+ paramExprs.addAll(inputExpr.getFnParams().exprs());
+ } else {
+ paramExprs.add(new SlotRef(inputDesc.getSlots().get(i + getGroupingExprs().size())));
+ }
FunctionCallExpr aggExpr = FunctionCallExpr.createMergeAggCall(
- inputExpr, Lists.newArrayList(aggExprParam));
+ inputExpr, paramExprs);
aggExpr.analyzeNoThrow(analyzer);
aggExprs.add(aggExpr);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
index 6f02521319..6487f8f9c1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java
@@ -1589,6 +1589,20 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo
minMaxSerializeOrFinalize, minMaxGetValue,
null, minMaxSerializeOrFinalize, true, true, false, true));
+ // vectorized
+ for (Type kt : Type.getSupportedTypes()) {
+ if (kt.isNull()) {
+ continue;
+ }
+ addBuiltin(AggregateFunction.createBuiltin("max_by", Lists.newArrayList(t, kt), t, Type.VARCHAR,
+ "", "", "", "", "", null, "",
+ true, true, false, true));
+ addBuiltin(AggregateFunction.createBuiltin("min_by", Lists.newArrayList(t, kt), t, Type.VARCHAR,
+ "", "", "", "", "", null, "",
+ true, true, false, true));
+ }
+
+
// NDV
// ndv return string
addBuiltin(AggregateFunction.createBuiltin("ndv", Lists.newArrayList(t), Type.BIGINT, Type.VARCHAR,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org