You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by si...@apache.org on 2023/05/09 07:32:07 UTC
[pinot] branch master updated: Support for ARG_MIN and ARG_MAX Functions (#10636)
This is an automated email from the ASF dual-hosted git repository.
siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 7a673fd604 Support for ARG_MIN and ARG_MAX Functions (#10636)
7a673fd604 is described below
commit 7a673fd6044396f79f1c4ae4b5d28bba12dc7b58
Author: Jia Guo <ji...@linkedin.com>
AuthorDate: Tue May 9 00:31:59 2023 -0700
Support for ARG_MIN and ARG_MAX Functions (#10636)
* Add ArgMinMax aggregation function
* Add more test cases, bug fix
* Address refactoring/doc related comments
* Address doc related comments
* Added more test cases
* Work around for DataBlock not able to ser/de empty array
* DataBlock bug reproduced
* Enhanced empty result return for group by queries, improved test cases
* Trigger Test
* Use another value filling scheme for multi-row result
* Removing Pinot prefix from parent and child aggregation function names
* Trigger Test
* Removing aggregation prefix from parent and child aggregation function names. Add more test cases. Refine error message.
* Trigger Test
---
.../broker/broker/helix/BaseBrokerStarter.java | 3 +
.../sql/parsers/rewriter/ArgMinMaxRewriter.java | 192 ++++++
.../sql/parsers/rewriter/QueryRewriterFactory.java | 2 +-
.../parsers/rewriter/ArgMinMaxRewriterTest.java | 76 +++
.../apache/pinot/core/common/ObjectSerDeUtils.java | 40 +-
.../function/AggregationFunctionFactory.java | 12 +
.../function/ChildAggregationFunction.java | 160 +++++
.../ChildArgMinMaxAggregationFunction.java | 39 ++
.../function/ParentAggregationFunction.java | 71 +++
.../ParentArgMinMaxAggregationFunction.java | 432 ++++++++++++++
.../groupby/DummyAggregationResultHolder.java | 55 ++
.../groupby/DummyGroupByResultHolder.java | 56 ++
.../ParentAggregationFunctionResultObject.java | 42 ++
.../argminmax/ArgMinMaxMeasuringValSetWrapper.java | 77 +++
.../utils/argminmax/ArgMinMaxObject.java | 353 +++++++++++
.../ArgMinMaxProjectionValSetWrapper.java | 70 +++
.../utils/argminmax/ArgMinMaxWrapperValSet.java | 106 ++++
.../query/reduce/AggregationDataTableReducer.java | 17 +-
.../core/query/reduce/GroupByDataTableReducer.java | 15 +-
.../rewriter/ParentAggregationResultRewriter.java | 222 +++++++
.../query/utils/rewriter/ResultRewriteUtils.java | 38 ++
.../core/query/utils/rewriter/ResultRewriter.java | 30 +
.../utils/rewriter/ResultRewriterFactory.java | 69 +++
.../core/query/utils/rewriter/RewriterResult.java | 41 ++
.../pinot/core/common/datablock/DataBlockTest.java | 46 ++
.../org/apache/pinot/queries/ArgMinMaxTest.java | 644 +++++++++++++++++++++
.../queries/ResultRewriterRegressionTest.java | 69 +++
.../pinot/segment/spi/AggregationFunctionType.java | 11 +-
.../apache/pinot/spi/utils/CommonConstants.java | 9 +
29 files changed, 2989 insertions(+), 8 deletions(-)
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
index cb24a4d5d8..7c68519b3d 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/BaseBrokerStarter.java
@@ -64,6 +64,7 @@ import org.apache.pinot.common.utils.config.TagNameUtils;
import org.apache.pinot.common.utils.helix.HelixHelper;
import org.apache.pinot.common.version.PinotVersion;
import org.apache.pinot.core.query.executor.sql.SqlQueryExecutor;
+import org.apache.pinot.core.query.utils.rewriter.ResultRewriterFactory;
import org.apache.pinot.core.transport.ListenerConfig;
import org.apache.pinot.core.transport.server.routing.stats.ServerRoutingStatsManager;
import org.apache.pinot.core.util.ListenerConfigUtil;
@@ -264,6 +265,8 @@ public abstract class BaseBrokerStarter implements ServiceStartable {
// Initialize QueryRewriterFactory
LOGGER.info("Initializing QueryRewriterFactory");
QueryRewriterFactory.init(_brokerConf.getProperty(Broker.CONFIG_OF_BROKER_QUERY_REWRITER_CLASS_NAMES));
+ LOGGER.info("Initializing ResultRewriterFactory");
+ ResultRewriterFactory.init(_brokerConf.getProperty(Broker.CONFIG_OF_BROKER_RESULT_REWRITER_CLASS_NAMES));
// Initialize FunctionRegistry before starting the broker request handler
FunctionRegistry.init();
boolean caseInsensitive =
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriter.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriter.java
new file mode 100644
index 0000000000..9fd29be69c
--- /dev/null
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriter.java
@@ -0,0 +1,192 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.sql.parsers.rewriter;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.apache.pinot.common.request.Expression;
+import org.apache.pinot.common.request.ExpressionType;
+import org.apache.pinot.common.request.Function;
+import org.apache.pinot.common.request.Literal;
+import org.apache.pinot.common.request.PinotQuery;
+import org.apache.pinot.spi.utils.CommonConstants;
+
+
+/**
+ * This rewriter rewrites ARG_MIN/ARG_MAX function, so that the functions with the same measuring expressions
+ * are consolidated and added as a single function with a list of projection expressions. For example, the query
+ * "SELECT ARG_MIN(col1, col2, col3), ARG_MIN(col1, col2, col4) FROM myTable" will be consolidated to a single
+ * function "PARENT_ARG_MIN(0, 2, col1, col2, col3, col4)", which is added to the end of the selection list.
+ * While the original ARG_MIN(col1, col2, col3) and ARG_MIN(col1, col2, col4) will be rewritten to
+ * CHILD_ARG_MIN(0, col3, col1, col2, col3) and CHILD_ARG_MIN(0, col4, col1, col2, col4) respectively.
+ * The 2 new parameters for CHILD_ARG_MIN are the function ID (0) and the projection column (col3/col4),
+ * used as column key in the parent aggregation result, during result rewriting.
+ * PARENT_ARG_MIN(0, 2, col1, col2, col3, col4) means a parent aggregation function with function ID 0,
+ * 2 measuring columns (col1, col2), 2 projection columns (col3, col4). The function ID is unique for each
+ * consolidated function with the same function type and measuring columns.
+ * Later, the aggregation result of the consolidated function will be filled into the corresponding
+ * columns of the original ARG_MIN/ARG_MAX. For more syntax details please refer to ParentAggregationFunction,
+ * ChildAggregationFunction and ParentAggregationResultRewriter.
+ */
+public class ArgMinMaxRewriter implements QueryRewriter {
+
+ private static final String ARG_MAX = "argmax";
+ private static final String ARG_MIN = "argmin";
+
+ private static final String ARG_MAX_PARENT =
+ CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARG_MAX;
+ private static final String ARG_MIN_PARENT =
+ CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARG_MIN;
+
+ @Override
+ public PinotQuery rewrite(PinotQuery pinotQuery) {
+ // This map stores the mapping from the list of measuring expressions to the set of projection expressions
+ HashMap<List<Expression>, Set<Expression>> argMinFunctionMap = new HashMap<>();
+ // This map stores the mapping from the list of measuring expressions to the function ID
+ HashMap<List<Expression>, Integer> argMinFunctionIDMap = new HashMap<>();
+
+ HashMap<List<Expression>, Set<Expression>> argMaxFunctionMap = new HashMap<>();
+ HashMap<List<Expression>, Integer> argMaxFunctionIDMap = new HashMap<>();
+
+ Iterator<Expression> iterator = pinotQuery.getSelectList().iterator();
+ while (iterator.hasNext()) {
+ boolean added = extractAndRewriteArgMinMaxFunctions(iterator.next(), argMaxFunctionMap, argMaxFunctionIDMap,
+ argMinFunctionMap, argMinFunctionIDMap);
+ // Remove the original function if it is not added, meaning it is a duplicate
+ if (!added) {
+ iterator.remove();
+ }
+ }
+
+ appendParentArgMinMaxFunctions(false, pinotQuery.getSelectList(), argMinFunctionMap, argMinFunctionIDMap);
+ appendParentArgMinMaxFunctions(true, pinotQuery.getSelectList(), argMaxFunctionMap, argMaxFunctionIDMap);
+
+ return pinotQuery;
+ }
+
+ /**
+ * This method appends the consolidated ARG_MIN/ARG_MAX functions to the end of the selection list.
+ * The consolidated function call will be in the following format:
+ * PARENT_ARG_MAX(functionID, numMeasuringColumns, measuringColumn1, measuringColumn2, ... projectionColumn1,
+ * projectionColumn2, ...)
+ * where functionID is the ID of the consolidated function, numMeasuringColumns is the number of measuring
+ * columns, measuringColumn1, measuringColumn2, ... are the measuring columns, and projectionColumn1,
+ * projectionColumn2, ... are the projection columns.
+ * The number of projection columns is the same as the number of ARG_MIN/ARG_MAX functions with the same
+ * measuring columns.
+ */
+ private void appendParentArgMinMaxFunctions(boolean isMax, List<Expression> selectList,
+ HashMap<List<Expression>, Set<Expression>> argMinMaxFunctionMap,
+ HashMap<List<Expression>, Integer> argMinMaxFunctionIDMap) {
+ for (Map.Entry<List<Expression>, Set<Expression>> entry : argMinMaxFunctionMap.entrySet()) {
+ Literal functionID = new Literal();
+ functionID.setLongValue(argMinMaxFunctionIDMap.get(entry.getKey()));
+ Literal numMeasuringColumns = new Literal();
+ numMeasuringColumns.setLongValue(entry.getKey().size());
+
+ Function parentFunction = new Function(isMax ? ARG_MAX_PARENT : ARG_MIN_PARENT);
+ parentFunction.addToOperands(new Expression(ExpressionType.LITERAL).setLiteral(functionID));
+ parentFunction.addToOperands(new Expression(ExpressionType.LITERAL).setLiteral(numMeasuringColumns));
+ for (Expression expression : entry.getKey()) {
+ parentFunction.addToOperands(expression);
+ }
+ for (Expression expression : entry.getValue()) {
+ parentFunction.addToOperands(expression);
+ }
+ selectList.add(new Expression(ExpressionType.FUNCTION).setFunctionCall(parentFunction));
+ }
+ }
+
+ /**
+ * This method extracts the ARG_MIN/ARG_MAX functions from the given expression and rewrites the functions
+ * with the same measuring expressions to use the same function ID.
+ * @return true if the function is not duplicated, false otherwise.
+ */
+ private boolean extractAndRewriteArgMinMaxFunctions(Expression expression,
+ HashMap<List<Expression>, Set<Expression>> argMaxFunctionMap,
+ HashMap<List<Expression>, Integer> argMaxFunctionIDMap,
+ HashMap<List<Expression>, Set<Expression>> argMinFunctionMap,
+ HashMap<List<Expression>, Integer> argMinFunctionIDMap) {
+ Function function = expression.getFunctionCall();
+ if (function == null) {
+ return true;
+ }
+ String functionName = function.getOperator();
+ if (!(functionName.equals(ARG_MIN) || functionName.equals(ARG_MAX))) {
+ return true;
+ }
+ List<Expression> operands = function.getOperands();
+ if (operands.size() < 2) {
+ throw new IllegalStateException("Invalid number of arguments for " + functionName + ", argmin/argmax should "
+ + "have at least 2 arguments, got: " + operands.size());
+ }
+ List<Expression> argMinMaxMeasuringExpressions = new ArrayList<>();
+ for (int i = 0; i < operands.size() - 1; i++) {
+ argMinMaxMeasuringExpressions.add(operands.get(i));
+ }
+ Expression argMinMaxProjectionExpression = operands.get(operands.size() - 1);
+
+ if (functionName.equals(ARG_MIN)) {
+ return updateArgMinMaxFunctionMap(argMinMaxMeasuringExpressions, argMinMaxProjectionExpression, argMinFunctionMap,
+ argMinFunctionIDMap, function);
+ } else {
+ return updateArgMinMaxFunctionMap(argMinMaxMeasuringExpressions, argMinMaxProjectionExpression, argMaxFunctionMap,
+ argMaxFunctionIDMap, function);
+ }
+ }
+
+ /**
+ * This method rewrites the ARG_MIN/ARG_MAX function with the given measuring expressions to use the same
+ * function ID.
+ * @return true if the function is not duplicated, false otherwise.
+ */
+ private boolean updateArgMinMaxFunctionMap(List<Expression> argMinMaxMeasuringExpressions,
+ Expression argMinMaxProjectionExpression, HashMap<List<Expression>, Set<Expression>> argMinMaxFunctionMap,
+ HashMap<List<Expression>, Integer> argMinMaxFunctionIDMap, Function function) {
+ int size = argMinMaxFunctionIDMap.size();
+ int id = argMinMaxFunctionIDMap.computeIfAbsent(argMinMaxMeasuringExpressions, (k) -> size);
+
+ AtomicBoolean added = new AtomicBoolean(true);
+
+ argMinMaxFunctionMap.compute(argMinMaxMeasuringExpressions, (k, v) -> {
+ if (v == null) {
+ v = new HashSet<>();
+ }
+ added.set(v.add(argMinMaxProjectionExpression));
+ return v;
+ });
+
+ String operator = function.operator;
+ function.setOperator(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + operator);
+
+ List<Expression> operands = function.getOperands();
+ operands.add(0, argMinMaxProjectionExpression);
+ Literal functionID = new Literal();
+ functionID.setLongValue(id);
+ operands.add(0, new Expression(ExpressionType.LITERAL).setLiteral(functionID));
+
+ return added.get();
+ }
+}
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java
index 4ae6c1bd93..ef36ee1080 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/QueryRewriterFactory.java
@@ -33,7 +33,7 @@ public class QueryRewriterFactory {
private static final Logger LOGGER = LoggerFactory.getLogger(QueryRewriterFactory.class);
- static final List<String> DEFAULT_QUERY_REWRITERS_CLASS_NAMES =
+ public static final List<String> DEFAULT_QUERY_REWRITERS_CLASS_NAMES =
ImmutableList.of(CompileTimeFunctionsInvoker.class.getName(), SelectionsRewriter.class.getName(),
PredicateComparisonRewriter.class.getName(), OrdinalsUpdater.class.getName(),
AliasApplier.class.getName(), NonAggregationGroupByToDistinctQueryRewriter.class.getName());
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriterTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriterTest.java
new file mode 100644
index 0000000000..479f609146
--- /dev/null
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/ArgMinMaxRewriterTest.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.sql.parsers.rewriter;
+
+import org.apache.pinot.sql.parsers.CalciteSqlParser;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class ArgMinMaxRewriterTest {
+ private static final QueryRewriter QUERY_REWRITER = new ArgMinMaxRewriter();
+
+ @Test
+ public void testQueryRewrite() {
+ testQueryRewrite("SELECT ARG_MIN(col1,col2), ARG_MIN(col1,col3) FROM myTable",
+ "SELECT CHILD_ARG_MIN(0,col2,col1,col2), "
+ + "CHILD_ARG_MIN(0,col3,col1,col3),"
+ + "PARENT_ARG_MIN(0,1,col1,col2,col3) FROM myTable");
+
+ testQueryRewrite("SELECT ARG_MIN(col1,col2), ARG_MIN(col1,col2) FROM myTable",
+ "SELECT CHILD_ARG_MIN(0,col2,col1,col2),"
+ + "PARENT_ARG_MIN(0,1,col1,col2) FROM myTable");
+
+ testQueryRewrite("SELECT ARG_MIN(col1,col2,col5), ARG_MIN(col1,col2,col6), ARG_MAX(col1,col2,col6) "
+ + "FROM myTable",
+ "SELECT CHILD_ARG_MIN(0,col5,col1,col2,col5), "
+ + "CHILD_ARG_MIN(0,col6,col1,col2,col6), "
+ + "CHILD_ARG_MAX(0,col6,col1,col2,col6),"
+ + "PARENT_ARG_MIN(0,2,col1,col2,col6,col5),"
+ + "PARENT_ARG_MAX(0,2,col1,col2,col6) FROM myTable");
+ }
+
+ @Test
+ public void testQueryRewriteWithOrderBy() {
+ testQueryRewrite("SELECT ARG_MIN(col1,col2,col5), ARG_MIN(col1,col3,col6),"
+ + "ARG_MIN(col3,col1,col6) FROM myTable GROUP BY col3 "
+ + "ORDER BY col3 DESC",
+ "SELECT CHILD_ARG_MIN(0,col5,col1,col2,col5), "
+ + "CHILD_ARG_MIN(1,col6,col1,col3,col6),"
+ + "CHILD_ARG_MIN(2,col6,col3,col1,col6),"
+ + "PARENT_ARG_MIN(1,2,col1,col3,col6),"
+ + "PARENT_ARG_MIN(0,2,col1,col2,col5),"
+ + "PARENT_ARG_MIN(2,2,col3,col1,col6)"
+ + "FROM myTable GROUP BY col3 ORDER BY col3 DESC");
+
+ testQueryRewrite("SELECT ARG_MIN(col1,col2,col5), ARG_MAX(col1,col2,col5) FROM myTable GROUP BY col3 "
+ + "ORDER BY ADD(co1, co3) DESC",
+ "SELECT CHILD_ARG_MIN(0,col5,col1,col2,col5),"
+ + "CHILD_ARG_MAX(0,col5,col1,col2,col5),"
+ + "PARENT_ARG_MIN(0,2,col1,col2,col5), "
+ + "PARENT_ARG_MAX(0,2,col1,col2,col5) "
+ + "FROM myTable GROUP BY col3 ORDER BY ADD(co1, co3) DESC");
+ }
+
+ private void testQueryRewrite(String original, String expected) {
+ assertEquals(QUERY_REWRITER.rewrite(CalciteSqlParser.compileToPinotQuery(original)),
+ CalciteSqlParser.compileToPinotQuery(expected));
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
index a01f02a5c9..9412014cef 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java
@@ -60,6 +60,7 @@ import java.util.Set;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.Sketch;
import org.apache.pinot.common.CustomObject;
+import org.apache.pinot.core.query.aggregation.utils.argminmax.ArgMinMaxObject;
import org.apache.pinot.core.query.distinct.DistinctTable;
import org.apache.pinot.core.query.utils.idset.IdSet;
import org.apache.pinot.core.query.utils.idset.IdSets;
@@ -127,7 +128,8 @@ public class ObjectSerDeUtils {
StringLongPair(31),
CovarianceTuple(32),
VarianceTuple(33),
- PinotFourthMoment(34);
+ PinotFourthMoment(34),
+ ArgMinMaxObject(35);
private final int _value;
@@ -213,6 +215,8 @@ public class ObjectSerDeUtils {
return ObjectType.VarianceTuple;
} else if (value instanceof PinotFourthMoment) {
return ObjectType.PinotFourthMoment;
+ } else if (value instanceof ArgMinMaxObject) {
+ return ObjectType.ArgMinMaxObject;
} else {
throw new IllegalArgumentException("Unsupported type of value: " + value.getClass().getSimpleName());
}
@@ -1199,6 +1203,37 @@ public class ObjectSerDeUtils {
}
};
+ public static final ObjectSerDe<ArgMinMaxObject> ARG_MIN_MAX_OBJECT_SER_DE =
+ new ObjectSerDe<ArgMinMaxObject>() {
+
+ @Override
+ public byte[] serialize(ArgMinMaxObject value) {
+ try {
+ return value.toBytes();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public ArgMinMaxObject deserialize(byte[] bytes) {
+ try {
+ return ArgMinMaxObject.fromBytes(bytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public ArgMinMaxObject deserialize(ByteBuffer byteBuffer) {
+ try {
+ return ArgMinMaxObject.fromByteBuffer(byteBuffer);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+
// NOTE: DO NOT change the order, it has to be the same order as the ObjectType
//@formatter:off
private static final ObjectSerDe[] SER_DES = {
@@ -1236,7 +1271,8 @@ public class ObjectSerDeUtils {
STRING_LONG_PAIR_SER_DE,
COVARIANCE_TUPLE_OBJECT_SER_DE,
VARIANCE_TUPLE_OBJECT_SER_DE,
- PINOT_FOURTH_MOMENT_OBJECT_SER_DE
+ PINOT_FOURTH_MOMENT_OBJECT_SER_DE,
+ ARG_MIN_MAX_OBJECT_SER_DE,
};
//@formatter:on
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 89e4a88c7a..ba3fc837c4 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -325,6 +325,18 @@ public class AggregationFunctionFactory {
return new FourthMomentAggregationFunction(firstArgument, FourthMomentAggregationFunction.Type.KURTOSIS);
case FOURTHMOMENT:
return new FourthMomentAggregationFunction(firstArgument, FourthMomentAggregationFunction.Type.MOMENT);
+ case PARENTARGMAX:
+ return new ParentArgMinMaxAggregationFunction(arguments, true);
+ case PARENTARGMIN:
+ return new ParentArgMinMaxAggregationFunction(arguments, false);
+ case CHILDARGMAX:
+ return new ChildArgMinMaxAggregationFunction(arguments, true);
+ case CHILDARGMIN:
+ return new ChildArgMinMaxAggregationFunction(arguments, false);
+ case ARGMAX:
+ case ARGMIN:
+ throw new IllegalArgumentException("Aggregation function: " + function
+ + " is only supported in selection without alias.");
default:
throw new IllegalArgumentException();
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
new file mode 100644
index 0000000000..e4d88226a2
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildAggregationFunction.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.DummyAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.DummyGroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.spi.utils.CommonConstants;
+
+
+/**
+ * Child aggregation function is used as a result placeholder during the query processing.
+ * It holds the position of the original aggregation function in the query
+ * and uses its name to denote which parent aggregation function it belongs to.
+ * The name also serves as the key to retrieve the result from the parent aggregation function
+ * result holder.
+ * Please look at getResultColumnName() for the detailed format of the name.
+ * Please look at ArgMinMaxRewriter as an example of how a child aggregation function is created.
+ */
+public abstract class ChildAggregationFunction implements AggregationFunction<Long, Long> {
+
+ private static final int CHILD_AGGREGATION_FUNCTION_ID_OFFSET = 0;
+ private static final int CHILD_AGGREGATION_FUNCTION_COLUMN_KEY_OFFSET = 1;
+ private final ExpressionContext _childFunctionKeyInParent;
+ private final List<ExpressionContext> _resultNameOperands;
+ private final ExpressionContext _childFunctionID;
+
+ ChildAggregationFunction(List<ExpressionContext> operands) {
+ _childFunctionID = operands.get(CHILD_AGGREGATION_FUNCTION_ID_OFFSET);
+ _childFunctionKeyInParent = operands.get(CHILD_AGGREGATION_FUNCTION_COLUMN_KEY_OFFSET);
+ _resultNameOperands = operands.subList(CHILD_AGGREGATION_FUNCTION_COLUMN_KEY_OFFSET + 1, operands.size());
+ }
+
+ @Override
+ public List<ExpressionContext> getInputExpressions() {
+ ArrayList<ExpressionContext> expressionContexts = new ArrayList<>();
+ expressionContexts.add(_childFunctionID);
+ expressionContexts.add(_childFunctionKeyInParent);
+ expressionContexts.addAll(_resultNameOperands);
+ return expressionContexts;
+ }
+
+ @Override
+ public final AggregationResultHolder createAggregationResultHolder() {
+ return new DummyAggregationResultHolder();
+ }
+
+ @Override
+ public final GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+ return new DummyGroupByResultHolder();
+ }
+
+ @Override
+ public final void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ }
+
+ @Override
+ public final void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ }
+
+ @Override
+ public final void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+ Map<ExpressionContext, BlockValSet> blockValSetMap) {
+ }
+
+ @Override
+ public final Long extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+ return 0L;
+ }
+
+ @Override
+ public final Long extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+ return 0L;
+ }
+
+ @Override
+ public final Long merge(Long intermediateResult1, Long intermediateResult2) {
+ return 0L;
+ }
+
+ @Override
+ public final DataSchema.ColumnDataType getIntermediateResultColumnType() {
+ return DataSchema.ColumnDataType.LONG;
+ }
+
+ @Override
+ public final DataSchema.ColumnDataType getFinalResultColumnType() {
+ return DataSchema.ColumnDataType.UNKNOWN;
+ }
+
+ @Override
+ public final Long extractFinalResult(Long longValue) {
+ return 0L;
+ }
+
+ /**
+ * The name of the column as follows:
+ * CHILD_AGGREGATION_NAME_PREFIX + actual function type + operands + CHILD_AGGREGATION_SEPERATOR
+ * + actual function type + parent aggregation function id + CHILD_KEY_SEPERATOR + column key in parent function
+ * e.g. if the child aggregation function is "argmax(0,a,b,x)", the name of the column is
+ * "pinotchildaggregationargmax(a,b,x)@argmax0_x"
+ */
+ @Override
+ public final String getResultColumnName() {
+ String type = getType().getName().toLowerCase();
+ return CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX
+ // above is the prefix for all child aggregation functions
+
+ + type + "(" + _resultNameOperands.stream().map(ExpressionContext::toString)
+ .collect(Collectors.joining(",")) + ")"
+ // above is the actual child aggregation function name we want to return to the user
+
+ + CommonConstants.RewriterConstants.CHILD_AGGREGATION_SEPERATOR
+ + type
+ + _childFunctionID.getLiteral().getStringValue()
+ + CommonConstants.RewriterConstants.CHILD_KEY_SEPERATOR
+ + _childFunctionKeyInParent.toString();
+ // above is the column key in the parent aggregation function
+ }
+
+ @Override
+ public final String toExplainString() {
+ StringBuilder stringBuilder = new StringBuilder(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX)
+ .append("_").append(getType().getName()).append('(');
+ int numArguments = getInputExpressions().size();
+ if (numArguments > 0) {
+ stringBuilder.append(getInputExpressions().get(0).toString());
+ for (int i = 1; i < numArguments; i++) {
+ stringBuilder.append(", ").append(getInputExpressions().get(i).toString());
+ }
+ }
+ return stringBuilder.append(')').toString();
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildArgMinMaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildArgMinMaxAggregationFunction.java
new file mode 100644
index 0000000000..408941bb30
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ChildArgMinMaxAggregationFunction.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.List;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+
+
+public class ChildArgMinMaxAggregationFunction extends ChildAggregationFunction {
+
+ private final boolean _isMax;
+
+ public ChildArgMinMaxAggregationFunction(List<ExpressionContext> operands, boolean isMax) {
+ super(operands);
+ _isMax = isMax;
+ }
+
+ @Override
+ public AggregationFunctionType getType() {
+ return _isMax ? AggregationFunctionType.ARGMAX : AggregationFunctionType.ARGMIN;
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentAggregationFunction.java
new file mode 100644
index 0000000000..831cf2d766
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentAggregationFunction.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.List;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.utils.ParentAggregationFunctionResultObject;
+import org.apache.pinot.spi.utils.CommonConstants;
+
+
+/**
+ * Base class for parent aggregation functions. A parent aggregation function is an aggregation function
+ * whose result is a nested data block containing multiple columns, each of which corresponds to a child
+ * aggregation function's result.
+ *
+ * <p>The argument at index {@link #PARENT_AGGREGATION_FUNCTION_ID_OFFSET} must be an integer literal; it is the
+ * id of the parent aggregation function and becomes part of the result column name.
+ */
+public abstract class ParentAggregationFunction<I, F extends ParentAggregationFunctionResultObject>
+    implements AggregationFunction<I, F> {
+
+  // Position, within the argument list, of the literal carrying the parent aggregation function id
+  protected static final int PARENT_AGGREGATION_FUNCTION_ID_OFFSET = 0;
+  protected List<ExpressionContext> _arguments;
+
+  /**
+   * @param arguments full argument list of the function; the first entry is read as the function id literal
+   */
+  ParentAggregationFunction(List<ExpressionContext> arguments) {
+    _arguments = arguments;
+  }
+
+  /**
+   * The final result of a parent aggregation function is always reported as {@code OBJECT}.
+   */
+  @Override
+  public final DataSchema.ColumnDataType getFinalResultColumnType() {
+    return DataSchema.ColumnDataType.OBJECT;
+  }
+
+  /**
+   * Builds the result column name as: parent aggregation name prefix + lower-cased name of the aggregation
+   * function type + id of the parent aggregation function.
+   * e.g. if the parent aggregation function is "argmax(0,3,a,b,c,x,y,z)", the name is the prefix followed by the
+   * lower-cased type name and "0".
+   *
+   * @return the result column name of this parent aggregation function
+   */
+  @Override
+  public final String getResultColumnName() {
+    // Locale.ROOT keeps the generated name independent of the JVM default locale (e.g. Turkish dotless i)
+    String lowerCaseTypeName = getType().getName().toLowerCase(java.util.Locale.ROOT);
+    int functionId = _arguments.get(PARENT_AGGREGATION_FUNCTION_ID_OFFSET).getLiteral().getIntValue();
+    return CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + lowerCaseTypeName + functionId;
+  }
+
+  /**
+   * Renders this function for query explain output as {@code <prefix>_<typeName>(arg0, arg1, ...)}.
+   *
+   * @return the explain string, with the arguments rendered via {@code toString()} and joined by ", "
+   */
+  public final String toExplainString() {
+    StringBuilder explain = new StringBuilder(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)
+        .append("_").append(getType().getName()).append('(');
+    int numArguments = _arguments.size();
+    for (int i = 0; i < numArguments; i++) {
+      if (i > 0) {
+        explain.append(", ");
+      }
+      explain.append(_arguments.get(i).toString());
+    }
+    return explain.append(')').toString();
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentArgMinMaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentArgMinMaxAggregationFunction.java
new file mode 100644
index 0000000000..9a735f10cc
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ParentArgMinMaxAggregationFunction.java
@@ -0,0 +1,432 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.core.query.aggregation.utils.argminmax.ArgMinMaxMeasuringValSetWrapper;
+import org.apache.pinot.core.query.aggregation.utils.argminmax.ArgMinMaxObject;
+import org.apache.pinot.core.query.aggregation.utils.argminmax.ArgMinMaxProjectionValSetWrapper;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+
+
+/**
+ * Parent aggregation function backing ARG_MIN / ARG_MAX.
+ *
+ * <p>The constructor reads the argument list as: function id, number of measuring columns, the measuring
+ * columns, and finally the projection columns. The intermediate and final results are both
+ * {@link ArgMinMaxObject}s.
+ */
+public class ParentArgMinMaxAggregationFunction extends ParentAggregationFunction<ArgMinMaxObject, ArgMinMaxObject> {
+
+ // list of columns that we do min/max on
+ private final List<ExpressionContext> _measuringColumns;
+ // list of columns that we project based on the min/max value
+ private final List<ExpressionContext> _projectionColumns;
+ // true if we are doing argmax, false if we are doing argmin
+ private final boolean _isMax;
+ // the id of the function, this is to associate the result of the parent aggregation function with the
+ // child aggregation functions having the same type (argmin/argmax) and measuring columns
+ private final ExpressionContext _functionIdContext;
+ // the literal argument that carries the number of measuring columns
+ private final ExpressionContext _numMeasuringColumnContext;
+ // number of columns that we do min/max on
+ private final int _numMeasuringColumns;
+ // number of columns that we project based on the min/max value
+ private final int _numProjectionColumns;
+
+ // The following thread-local variables are not set by the constructor; they are populated by the
+ // initialize* methods of this class
+
+ // The wrapper classes for the block value sets
+ private final ThreadLocal<List<ArgMinMaxMeasuringValSetWrapper>> _argMinMaxWrapperMeasuringColumnSets =
+ ThreadLocal.withInitial(ArrayList::new);
+ private final ThreadLocal<List<ArgMinMaxProjectionValSetWrapper>> _argMinMaxWrapperProjectionColumnSets =
+ ThreadLocal.withInitial(ArrayList::new);
+ // The schema for the measuring columns and projection columns
+ private final ThreadLocal<DataSchema> _measuringColumnSchema = new ThreadLocal<>();
+ private final ThreadLocal<DataSchema> _projectionColumnSchema = new ThreadLocal<>();
+ // Whether the schemas have been initialized for the current thread
+ private final ThreadLocal<Boolean> _schemaInitialized = ThreadLocal.withInitial(() -> false);
+
+  /**
+   * Splits the argument list into its parts: {@code arguments[0]} is the function id, {@code arguments[1]} is an
+   * integer literal with the number of measuring columns, followed by that many measuring columns; every
+   * remaining argument is a projection column.
+   *
+   * @param arguments the full argument list of the function
+   * @param isMax {@code true} for argmax, {@code false} for argmin
+   */
+  public ParentArgMinMaxAggregationFunction(List<ExpressionContext> arguments, boolean isMax) {
+    super(arguments);
+    _isMax = isMax;
+
+    // fixed-position header arguments
+    _functionIdContext = arguments.get(0);
+    _numMeasuringColumnContext = arguments.get(1);
+    _numMeasuringColumns = _numMeasuringColumnContext.getLiteral().getIntValue();
+
+    // measuring columns come right after the header, projection columns take the rest
+    int projectionStart = 2 + _numMeasuringColumns;
+    _measuringColumns = arguments.subList(2, projectionStart);
+    _projectionColumns = arguments.subList(projectionStart, arguments.size());
+    _numProjectionColumns = _projectionColumns.size();
+  }
+
+  /**
+   * Returns {@code ARGMAX} when this function computes the maximum, {@code ARGMIN} otherwise.
+   */
+  @Override
+  public AggregationFunctionType getType() {
+    if (_isMax) {
+      return AggregationFunctionType.ARGMAX;
+    }
+    return AggregationFunctionType.ARGMIN;
+  }
+
+  /**
+   * Returns a new list holding, in order: the function id, the measuring-column count literal, the measuring
+   * columns and the projection columns.
+   */
+  @Override
+  public List<ExpressionContext> getInputExpressions() {
+    List<ExpressionContext> inputExpressions = new ArrayList<>(2 + _numMeasuringColumns + _numProjectionColumns);
+    inputExpressions.add(_functionIdContext);
+    inputExpressions.add(_numMeasuringColumnContext);
+    inputExpressions.addAll(_measuringColumns);
+    inputExpressions.addAll(_projectionColumns);
+    return inputExpressions;
+  }
+
+  /**
+   * Creates the holder for the non-group-by result; the result is kept as an object
+   * (an {@link ArgMinMaxObject} is stored into it by {@code aggregate}).
+   */
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  /**
+   * Creates the holder for per-group results; each group's result is kept as an object.
+   *
+   * @param initialCapacity passed through to the {@link ObjectGroupByResultHolder} constructor
+   * @param maxCapacity passed through to the {@link ObjectGroupByResultHolder} constructor
+   */
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
+  }
+
+  /**
+   * Aggregates one block of rows into the single (non-group-by) result stored in {@code aggregationResultHolder}.
+   *
+   * <p>All rows of the block are first compared against the current extremum key; only the rows that still carry
+   * the extremum key at the end of the block have their projection values copied into the result.
+   *
+   * @param length number of rows in the block
+   * @param aggregationResultHolder holder of the running {@link ArgMinMaxObject}; created on first use
+   * @param blockValSetMap value sets of the current block, keyed by input expression
+   */
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    // Refresh the wrappers for every block, exactly like the group-by paths do. Doing this only while the holder
+    // is still empty would leave the wrappers pointing at the first block when later blocks are aggregated.
+    initializeWithNewDataBlocks(blockValSetMap);
+
+    ArgMinMaxObject argMinMaxObject = aggregationResultHolder.getResult();
+    if (argMinMaxObject == null) {
+      argMinMaxObject = new ArgMinMaxObject(_measuringColumnSchema.get(), _projectionColumnSchema.get());
+    }
+
+    List<ArgMinMaxMeasuringValSetWrapper> measuringSets = _argMinMaxWrapperMeasuringColumnSets.get();
+    List<ArgMinMaxProjectionValSetWrapper> projectionSets = _argMinMaxWrapperProjectionColumnSets.get();
+
+    // ids of the rows in this block that carry the extremum key; a primitive buffer avoids boxing every row id
+    int[] rowIds = new int[length];
+    int numRowIds = 0;
+    // whether a row of this block replaced the key, i.e. previously collected projection values are obsolete
+    boolean keyReplaced = false;
+    for (int i = 0; i < length; i++) {
+      int compareResult = argMinMaxObject.compareAndSetKey(measuringSets, i, _isMax);
+      if (compareResult > 0) {
+        // new key is set, forget the rows collected so far and start over with this row
+        keyReplaced = true;
+        numRowIds = 0;
+        rowIds[numRowIds++] = i;
+      } else if (compareResult == 0) {
+        // same key, remember the row
+        rowIds[numRowIds++] = i;
+      }
+    }
+
+    // For all the rows that are associated with the extremum key, add the projection columns. When the key was
+    // replaced, the first row overwrites the existing values (mirrors updateGroupByResult) so that values
+    // collected for an older key, possibly from an earlier block, do not linger in the result.
+    for (int j = 0; j < numRowIds; j++) {
+      if (j == 0 && keyReplaced) {
+        argMinMaxObject.setToNewVal(projectionSets, rowIds[j]);
+      } else {
+        argMinMaxObject.addVal(projectionSets, rowIds[j]);
+      }
+    }
+
+    aggregationResultHolder.setValue(argMinMaxObject);
+  }
+
+  /**
+   * Prepares the thread-local state for a new block: initializes the schemas and wrappers on first use, or points
+   * the existing wrappers at the new block value sets afterwards.
+   *
+   * @param blockValSetMap value sets of the new block; {@code null} selects the empty-doc-set initialization
+   */
+  private void initializeWithNewDataBlocks(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    if (blockValSetMap == null) {
+      initializeForEmptyDocSet();
+      return;
+    }
+
+    if (!_schemaInitialized.get()) {
+      // first block seen by this thread: the schema is initialized only once
+      _schemaInitialized.set(true);
+      // setup measuring column names and types
+      initializeMeasuringColumnValSet(blockValSetMap);
+      // setup projection column names and types
+      initializeProjectionColumnValSet(blockValSetMap);
+      return;
+    }
+
+    // the schema is already initialized, just update the wrappers with the new block value sets
+    List<ArgMinMaxMeasuringValSetWrapper> measuringSets = _argMinMaxWrapperMeasuringColumnSets.get();
+    for (int col = 0; col < _numMeasuringColumns; col++) {
+      BlockValSet measuringValSet = blockValSetMap.get(_measuringColumns.get(col));
+      measuringSets.get(col).setNewBlock(measuringValSet);
+    }
+    List<ArgMinMaxProjectionValSetWrapper> projectionSets = _argMinMaxWrapperProjectionColumnSets.get();
+    for (int col = 0; col < _numProjectionColumns; col++) {
+      BlockValSet projectionValSet = blockValSetMap.get(_projectionColumns.get(col));
+      projectionSets.get(col).setNewBlock(projectionValSet);
+    }
+  }
+
+  /**
+   * Creates one projection wrapper per projection column and builds the projection column schema.
+   *
+   * <p>The wrapper is created with the column's own type, while the schema records INT for BOOLEAN, LONG for
+   * TIMESTAMP and STRING for JSON (and the matching array types for multi-value columns).
+   *
+   * @param blockValSetMap value sets of the current block, keyed by input expression
+   * @throws IllegalStateException if a projection column has a value type that is not handled
+   */
+  private void initializeProjectionColumnValSet(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    List<ArgMinMaxProjectionValSetWrapper> projectionSets = _argMinMaxWrapperProjectionColumnSets.get();
+    int numColumns = _projectionColumns.size();
+    String[] columnNames = new String[numColumns];
+    DataSchema.ColumnDataType[] columnTypes = new DataSchema.ColumnDataType[numColumns];
+    for (int i = 0; i < numColumns; i++) {
+      ExpressionContext projectionColumn = _projectionColumns.get(i);
+      columnNames[i] = projectionColumn.toString();
+      BlockValSet blockValSet = blockValSetMap.get(projectionColumn);
+      boolean singleValue = blockValSet.isSingleValue();
+
+      // type handed to the wrapper
+      DataSchema.ColumnDataType wrapperType;
+      if (singleValue) {
+        switch (blockValSet.getValueType()) {
+          case INT:
+            wrapperType = DataSchema.ColumnDataType.INT;
+            break;
+          case BOOLEAN:
+            wrapperType = DataSchema.ColumnDataType.BOOLEAN;
+            break;
+          case LONG:
+            wrapperType = DataSchema.ColumnDataType.LONG;
+            break;
+          case TIMESTAMP:
+            wrapperType = DataSchema.ColumnDataType.TIMESTAMP;
+            break;
+          case FLOAT:
+            wrapperType = DataSchema.ColumnDataType.FLOAT;
+            break;
+          case DOUBLE:
+            wrapperType = DataSchema.ColumnDataType.DOUBLE;
+            break;
+          case STRING:
+            wrapperType = DataSchema.ColumnDataType.STRING;
+            break;
+          case JSON:
+            wrapperType = DataSchema.ColumnDataType.JSON;
+            break;
+          case BYTES:
+            wrapperType = DataSchema.ColumnDataType.BYTES;
+            break;
+          case BIG_DECIMAL:
+            wrapperType = DataSchema.ColumnDataType.BIG_DECIMAL;
+            break;
+          default:
+            throw new IllegalStateException(
+                "Cannot compute ArgMinMax projection on non-comparable type: " + blockValSet.getValueType());
+        }
+      } else {
+        switch (blockValSet.getValueType()) {
+          case INT:
+            wrapperType = DataSchema.ColumnDataType.INT_ARRAY;
+            break;
+          case BOOLEAN:
+            wrapperType = DataSchema.ColumnDataType.BOOLEAN_ARRAY;
+            break;
+          case LONG:
+            wrapperType = DataSchema.ColumnDataType.LONG_ARRAY;
+            break;
+          case TIMESTAMP:
+            wrapperType = DataSchema.ColumnDataType.TIMESTAMP_ARRAY;
+            break;
+          case FLOAT:
+            wrapperType = DataSchema.ColumnDataType.FLOAT_ARRAY;
+            break;
+          case DOUBLE:
+            wrapperType = DataSchema.ColumnDataType.DOUBLE_ARRAY;
+            break;
+          case STRING:
+            wrapperType = DataSchema.ColumnDataType.STRING_ARRAY;
+            break;
+          case BYTES:
+            wrapperType = DataSchema.ColumnDataType.BYTES_ARRAY;
+            break;
+          default:
+            throw new IllegalStateException(
+                "Cannot compute ArgMinMax projection on non-comparable type: " + blockValSet.getValueType());
+        }
+      }
+
+      // type recorded in the schema; differs from the wrapper type only for BOOLEAN, TIMESTAMP and JSON
+      DataSchema.ColumnDataType schemaType;
+      switch (wrapperType) {
+        case BOOLEAN:
+          schemaType = DataSchema.ColumnDataType.INT;
+          break;
+        case TIMESTAMP:
+          schemaType = DataSchema.ColumnDataType.LONG;
+          break;
+        case JSON:
+          schemaType = DataSchema.ColumnDataType.STRING;
+          break;
+        case BOOLEAN_ARRAY:
+          schemaType = DataSchema.ColumnDataType.INT_ARRAY;
+          break;
+        case TIMESTAMP_ARRAY:
+          schemaType = DataSchema.ColumnDataType.LONG_ARRAY;
+          break;
+        default:
+          schemaType = wrapperType;
+          break;
+      }
+
+      projectionSets.add(new ArgMinMaxProjectionValSetWrapper(singleValue, wrapperType, blockValSet));
+      columnTypes[i] = schemaType;
+    }
+    // setup projection column schema
+    _projectionColumnSchema.set(new DataSchema(columnNames, columnTypes));
+  }
+
+  /**
+   * Creates one measuring wrapper per measuring column and builds the measuring column schema.
+   *
+   * <p>The wrapper is created with the column's own type, while the schema records INT for BOOLEAN and LONG for
+   * TIMESTAMP.
+   *
+   * @param blockValSetMap value sets of the current block, keyed by input expression
+   * @throws IllegalStateException if a measuring column is multi-valued or has a value type that is not handled
+   */
+  private void initializeMeasuringColumnValSet(Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    List<ArgMinMaxMeasuringValSetWrapper> measuringSets = _argMinMaxWrapperMeasuringColumnSets.get();
+    String[] columnNames = new String[_numMeasuringColumns];
+    DataSchema.ColumnDataType[] columnTypes = new DataSchema.ColumnDataType[_numMeasuringColumns];
+    for (int i = 0; i < _numMeasuringColumns; i++) {
+      ExpressionContext measuringColumn = _measuringColumns.get(i);
+      columnNames[i] = measuringColumn.toString();
+      BlockValSet blockValSet = blockValSetMap.get(measuringColumn);
+      Preconditions.checkState(blockValSet.isSingleValue(),
+          "ArgMinMax only supports single-valued measuring columns");
+
+      // type handed to the wrapper
+      DataSchema.ColumnDataType wrapperType;
+      switch (blockValSet.getValueType()) {
+        case INT:
+          wrapperType = DataSchema.ColumnDataType.INT;
+          break;
+        case BOOLEAN:
+          wrapperType = DataSchema.ColumnDataType.BOOLEAN;
+          break;
+        case LONG:
+          wrapperType = DataSchema.ColumnDataType.LONG;
+          break;
+        case TIMESTAMP:
+          wrapperType = DataSchema.ColumnDataType.TIMESTAMP;
+          break;
+        case FLOAT:
+          wrapperType = DataSchema.ColumnDataType.FLOAT;
+          break;
+        case DOUBLE:
+          wrapperType = DataSchema.ColumnDataType.DOUBLE;
+          break;
+        case STRING:
+          wrapperType = DataSchema.ColumnDataType.STRING;
+          break;
+        case BIG_DECIMAL:
+          wrapperType = DataSchema.ColumnDataType.BIG_DECIMAL;
+          break;
+        default:
+          throw new IllegalStateException(
+              "Cannot compute ArgMinMax measuring on non-comparable type: " + blockValSet.getValueType());
+      }
+
+      // type recorded in the schema; differs from the wrapper type only for BOOLEAN and TIMESTAMP
+      DataSchema.ColumnDataType schemaType;
+      switch (wrapperType) {
+        case BOOLEAN:
+          schemaType = DataSchema.ColumnDataType.INT;
+          break;
+        case TIMESTAMP:
+          schemaType = DataSchema.ColumnDataType.LONG;
+          break;
+        default:
+          schemaType = wrapperType;
+          break;
+      }
+
+      measuringSets.add(new ArgMinMaxMeasuringValSetWrapper(true, wrapperType, blockValSet));
+      columnTypes[i] = schemaType;
+    }
+    // setup measuring column schema
+    _measuringColumnSchema.set(new DataSchema(columnNames, columnTypes));
+  }
+
+  /**
+   * Initializes placeholder schemas when there is no block to read the column types from, i.e. the docIdSet is
+   * empty because no rows match the filter. Every measuring and projection column is recorded as STRING.
+   * Does nothing if the schemas are already initialized for the current thread.
+   */
+  private void initializeForEmptyDocSet() {
+    if (!_schemaInitialized.get()) {
+      _schemaInitialized.set(true);
+
+      String[] measuringNames = new String[_numMeasuringColumns];
+      DataSchema.ColumnDataType[] measuringTypes = new DataSchema.ColumnDataType[_numMeasuringColumns];
+      for (int col = 0; col < _numMeasuringColumns; col++) {
+        measuringNames[col] = _measuringColumns.get(col).toString();
+        measuringTypes[col] = DataSchema.ColumnDataType.STRING;
+      }
+
+      String[] projectionNames = new String[_numProjectionColumns];
+      DataSchema.ColumnDataType[] projectionTypes = new DataSchema.ColumnDataType[_numProjectionColumns];
+      for (int col = 0; col < _numProjectionColumns; col++) {
+        projectionNames[col] = _projectionColumns.get(col).toString();
+        projectionTypes[col] = DataSchema.ColumnDataType.STRING;
+      }
+
+      _measuringColumnSchema.set(new DataSchema(measuringNames, measuringTypes));
+      _projectionColumnSchema.set(new DataSchema(projectionNames, projectionTypes));
+    }
+  }
+
+  /**
+   * Aggregates one block of rows where every row belongs to exactly one group.
+   *
+   * @param length number of rows in the block
+   * @param groupKeyArray group key of each row
+   * @param groupByResultHolder holder of the per-group {@link ArgMinMaxObject}s
+   * @param blockValSetMap value sets of the current block, keyed by input expression
+   */
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    // point the wrappers at the current block before reading any row
+    initializeWithNewDataBlocks(blockValSetMap);
+    for (int rowId = 0; rowId < length; rowId++) {
+      updateGroupByResult(groupByResultHolder, rowId, groupKeyArray[rowId]);
+    }
+  }
+
+  /**
+   * Folds row {@code i} of the current block into the result of group {@code groupKey}, creating and registering
+   * the group's {@link ArgMinMaxObject} on first use.
+   *
+   * @param groupByResultHolder holder of the per-group results
+   * @param i row id within the current block
+   * @param groupKey key of the group the row belongs to
+   */
+  private void updateGroupByResult(GroupByResultHolder groupByResultHolder, int i, int groupKey) {
+    ArgMinMaxObject groupResult = groupByResultHolder.getResult(groupKey);
+    if (groupResult == null) {
+      groupResult = new ArgMinMaxObject(_measuringColumnSchema.get(), _projectionColumnSchema.get());
+      groupByResultHolder.setValueForKey(groupKey, groupResult);
+    }
+
+    int comparison = groupResult.compareAndSetKey(_argMinMaxWrapperMeasuringColumnSets.get(), i, _isMax);
+    if (comparison > 0) {
+      // the row became the new key: its projection values replace the ones collected so far
+      groupResult.setToNewVal(_argMinMaxWrapperProjectionColumnSets.get(), i);
+    } else if (comparison == 0) {
+      // the row ties with the current key: its projection values are appended
+      groupResult.addVal(_argMinMaxWrapperProjectionColumnSets.get(), i);
+    }
+  }
+
+  /**
+   * Aggregates one block of rows where a row may belong to several groups; the row is folded into each of them.
+   *
+   * @param length number of rows in the block
+   * @param groupKeysArray group keys of each row
+   * @param groupByResultHolder holder of the per-group {@link ArgMinMaxObject}s
+   * @param blockValSetMap value sets of the current block, keyed by input expression
+   */
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    // point the wrappers at the current block before reading any row
+    initializeWithNewDataBlocks(blockValSetMap);
+    for (int rowId = 0; rowId < length; rowId++) {
+      int[] groupKeys = groupKeysArray[rowId];
+      for (int groupKey : groupKeys) {
+        updateGroupByResult(groupByResultHolder, rowId, groupKey);
+      }
+    }
+  }
+
+  /**
+   * Returns the aggregated result. When nothing was aggregated, a new {@link ArgMinMaxObject} is returned
+   * instead, after making sure the schemas are initialized.
+   */
+  @Override
+  public ArgMinMaxObject extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    ArgMinMaxObject aggregated = aggregationResultHolder.getResult();
+    if (aggregated != null) {
+      return aggregated;
+    }
+    // nothing was aggregated into the holder; passing null selects the empty-doc-set initialization
+    initializeWithNewDataBlocks(null);
+    return new ArgMinMaxObject(_measuringColumnSchema.get(), _projectionColumnSchema.get());
+  }
+
+  /**
+   * Returns whatever the holder stores for {@code groupKey}, without any further processing.
+   */
+  @Override
+  public ArgMinMaxObject extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    return groupByResultHolder.getResult(groupKey);
+  }
+
+  /**
+   * Merges two intermediate results by delegating to {@code ArgMinMaxObject.merge} on the first one.
+   *
+   * @return the value returned by {@code intermediateResult1.merge(intermediateResult2, isMax)}
+   */
+  @Override
+  public ArgMinMaxObject merge(ArgMinMaxObject intermediateResult1, ArgMinMaxObject intermediateResult2) {
+    return intermediateResult1.merge(intermediateResult2, _isMax);
+  }
+
+  /**
+   * The intermediate result is reported as {@code OBJECT}.
+   */
+  @Override
+  public DataSchema.ColumnDataType getIntermediateResultColumnType() {
+    return DataSchema.ColumnDataType.OBJECT;
+  }
+
+  /**
+   * The final result is the intermediate result itself; it is returned unchanged.
+   */
+  @Override
+  public ArgMinMaxObject extractFinalResult(ArgMinMaxObject argMinMaxObject) {
+    return argMinMaxObject;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyAggregationResultHolder.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyAggregationResultHolder.java
new file mode 100644
index 0000000000..545cd7973c
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyAggregationResultHolder.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.groupby;
+
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+
+
+/**
+ * Placeholder {@link AggregationResultHolder} whose setters do nothing and whose getters return
+ * {@code 0} / {@code null}.
+ * This is used for ChildAggregationFunction.
+ */
+public class DummyAggregationResultHolder implements AggregationResultHolder {
+
+  @Override
+  public void setValue(double value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public void setValue(int value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public void setValue(Object value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public double getDoubleResult() {
+    return 0.0;
+  }
+
+  @Override
+  public int getIntResult() {
+    return 0;
+  }
+
+  @Override
+  public <T> T getResult() {
+    // nothing is ever stored
+    return null;
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyGroupByResultHolder.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyGroupByResultHolder.java
new file mode 100644
index 0000000000..f76e8967cf
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DummyGroupByResultHolder.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.groupby;
+
+/**
+ * Placeholder {@link GroupByResultHolder} whose setters and capacity handling do nothing and whose getters
+ * return {@code 0} / {@code null} for every group key.
+ * This is used for ChildAggregationFunction.
+ */
+public class DummyGroupByResultHolder implements GroupByResultHolder {
+
+  @Override
+  public void setValueForKey(int groupKey, double value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public void setValueForKey(int groupKey, int value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public void setValueForKey(int groupKey, Object value) {
+    // no-op: the value is discarded
+  }
+
+  @Override
+  public double getDoubleResult(int groupKey) {
+    return 0.0;
+  }
+
+  @Override
+  public int getIntResult(int groupKey) {
+    return 0;
+  }
+
+  @Override
+  public <T> T getResult(int groupKey) {
+    // nothing is ever stored
+    return null;
+  }
+
+  @Override
+  public void ensureCapacity(int capacity) {
+    // no-op: there is no backing storage to grow
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/ParentAggregationFunctionResultObject.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/ParentAggregationFunctionResultObject.java
new file mode 100644
index 0000000000..ee441b74e5
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/ParentAggregationFunctionResultObject.java
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.utils;
+
+import java.io.Serializable;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.utils.rewriter.ParentAggregationResultRewriter;
+
+
+/**
+ * Interface for the result of a parent aggregation function. The result can be used to populate the results of
+ * the corresponding child aggregation functions: each child aggregation function has a corresponding column in
+ * the result schema. Please see {@link ParentAggregationResultRewriter} for more details.
+ */
+public interface ParentAggregationFunctionResultObject
+    extends Comparable<ParentAggregationFunctionResultObject>, Serializable {
+
+  /**
+   * Returns the nested value of the field at the given row and column.
+   */
+  Object getField(int rowId, int colId);
+
+  /**
+   * Returns the total number of rows.
+   */
+  int getNumberOfRows();
+
+  /**
+   * Returns the nested schema of the result.
+   */
+  DataSchema getSchema();
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxMeasuringValSetWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxMeasuringValSetWrapper.java
new file mode 100644
index 0000000000..0b33aa4363
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxMeasuringValSetWrapper.java
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.utils.argminmax;
+
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+
+
+/**
+ * Wrapper class for measuring columns in argmin/max aggregation function.
+ * Mainly used to do comparison without boxing primitive types.
+ *
+ * <p>The typed value arrays and the data type read here are inherited from {@link ArgMinMaxWrapperValSet}.
+ */
+public class ArgMinMaxMeasuringValSetWrapper extends ArgMinMaxWrapperValSet {
+
+  /**
+   * @param isSingleValue whether the wrapped column is single-valued
+   * @param dataType data type used to pick the value array in {@link #getComparable} and {@link #compare}
+   * @param blockValSet value set of the first block to wrap
+   */
+  public ArgMinMaxMeasuringValSetWrapper(boolean isSingleValue, DataSchema.ColumnDataType dataType,
+      BlockValSet blockValSet) {
+    super(dataType, isSingleValue);
+    setNewBlock(blockValSet);
+  }
+
+  /**
+   * Returns the value at row {@code i} as a {@link Comparable}; primitive values are boxed by this call.
+   *
+   * @param i row id within the current block
+   * @throws IllegalStateException if the data type of this wrapper is not supported
+   */
+  public Comparable getComparable(int i) {
+    switch (_dataType) {
+      case INT:
+      case BOOLEAN:
+        return _intValues[i];
+      case LONG:
+      case TIMESTAMP:
+        return _longValues[i];
+      case FLOAT:
+        return _floatValues[i];
+      case DOUBLE:
+        return _doublesValues[i];
+      case STRING:
+      case BIG_DECIMAL:
+        return (Comparable) _objectsValues[i];
+      default:
+        throw new IllegalStateException("Unsupported data type: " + _dataType);
+    }
+  }
+
+  /**
+   * Compares {@code o} against the value at row {@code i}. Note the operand order: {@code o} is the left-hand
+   * side, so the result is negative when {@code o} is smaller than the row value and positive when it is larger.
+   *
+   * @param i row id within the current block
+   * @param o value to compare; must be of the Java type matching the data type of this wrapper
+   * @return negative, zero or positive as {@code o} is less than, equal to or greater than the row value
+   * @throws IllegalStateException if the data type of this wrapper is not supported
+   */
+  public int compare(int i, Object o) {
+    switch (_dataType) {
+      case INT:
+      case BOOLEAN:
+        return Integer.compare((Integer) o, _intValues[i]);
+      case LONG:
+      case TIMESTAMP:
+        return Long.compare((Long) o, _longValues[i]);
+      case FLOAT:
+        return Float.compare((Float) o, _floatValues[i]);
+      case DOUBLE:
+        return Double.compare((Double) o, _doublesValues[i]);
+      case STRING:
+        return ((String) o).compareTo((String) _objectsValues[i]);
+      case BIG_DECIMAL:
+        return ((java.math.BigDecimal) o).compareTo((java.math.BigDecimal) _objectsValues[i]);
+      default:
+        // the ": " separator keeps the type readable in the message, as in getComparable
+        throw new IllegalStateException("Unsupported data type in comparison: " + _dataType);
+    }
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxObject.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxObject.java
new file mode 100644
index 0000000000..176902a4f4
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxObject.java
@@ -0,0 +1,353 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pinot.core.query.aggregation.utils.argminmax;
+
+import com.google.common.base.Preconditions;
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import javax.annotation.Nonnull;
+import org.apache.pinot.common.datablock.DataBlock;
+import org.apache.pinot.common.datablock.DataBlockUtils;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.datablock.DataBlockBuilder;
+import org.apache.pinot.core.query.aggregation.utils.ParentAggregationFunctionResultObject;
+
+
+public class ArgMinMaxObject implements ParentAggregationFunctionResultObject {
+
+ // if the object is created but not yet populated, this happens e.g. when a server has no data for
+ // the query and returns a default value
+ enum ObjectNullState {
+ NULL(0),
+ NON_NULL(1);
+
+ final int _state;
+
+ ObjectNullState(int i) {
+ _state = i;
+ }
+
+ int getState() {
+ return _state;
+ }
+ }
+
+ // if the object contains non null values
+ private boolean _isNull;
+
+ // if the value is stored in a mutable list, this is false only when the Object is deserialized from a byte buffer
+ // if the object is mutable, it means that the object is read only and the values are stored in
+ // _immutableMeasuringKeys and _immutableProjectionVals, otherwise we read and write from _extremumMeasuringKeys
+ // and _extremumProjectionValues
+ private boolean _mutable;
+
+ // the schema of the measuring columns
+ private final DataSchema _measuringSchema;
+ // the schema of the projection columns
+ private final DataSchema _projectionSchema;
+
+ // the size of the extremum key cols and value cols
+ private final int _sizeOfExtremumMeasuringKeys;
+ private final int _sizeOfExtremumProjectionVals;
+
+ // the current extremum keys, keys are the extremum values of the measuring columns,
+ // used for comparison
+ private Comparable[] _extremumMeasuringKeys = null;
+ // the current extremum values, values are the values of the projection columns
+ // associated with the minimum measuring column, used for projection
+ private final List<Object[]> _extremumProjectionValues = new ArrayList<>();
+
+ // used for ser/de
+ private DataBlock _immutableMeasuringKeys;
+ private DataBlock _immutableProjectionVals;
+
+ public ArgMinMaxObject(DataSchema measuringSchema, DataSchema projectionSchema) {
+ _isNull = true;
+ _mutable = true;
+
+ _measuringSchema = measuringSchema;
+ _projectionSchema = projectionSchema;
+
+ _sizeOfExtremumMeasuringKeys = _measuringSchema.size();
+ _sizeOfExtremumProjectionVals = _projectionSchema.size();
+ }
+
+ public ArgMinMaxObject(ByteBuffer byteBuffer)
+ throws IOException {
+ _mutable = false;
+ _isNull = byteBuffer.getInt() == ObjectNullState.NULL.getState();
+ byteBuffer = byteBuffer.slice();
+ _immutableMeasuringKeys = DataBlockUtils.getDataBlock(byteBuffer);
+ byteBuffer = byteBuffer.slice();
+ _immutableProjectionVals = DataBlockUtils.getDataBlock(byteBuffer);
+
+ _measuringSchema = _immutableMeasuringKeys.getDataSchema();
+ _projectionSchema = _immutableProjectionVals.getDataSchema();
+
+ _sizeOfExtremumMeasuringKeys = _measuringSchema.size();
+ _sizeOfExtremumProjectionVals = _projectionSchema.size();
+ }
+
+ public static ArgMinMaxObject fromBytes(byte[] bytes)
+ throws IOException {
+ return fromByteBuffer(ByteBuffer.wrap(bytes));
+ }
+
+ public static ArgMinMaxObject fromByteBuffer(ByteBuffer byteBuffer)
+ throws IOException {
+ return new ArgMinMaxObject(byteBuffer);
+ }
+
+ // used for result serialization
+ @Nonnull
+ public byte[] toBytes()
+ throws IOException {
+ ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream);
+ if (_isNull) {
+ // serialize the null object with schemas
+ dataOutputStream.writeInt(ObjectNullState.NULL.getState());
+ _immutableMeasuringKeys = DataBlockBuilder.buildFromRows(Collections.emptyList(), _measuringSchema);
+ _immutableProjectionVals = DataBlockBuilder.buildFromRows(Collections.emptyList(), _projectionSchema);
+ } else {
+ dataOutputStream.writeInt(ObjectNullState.NON_NULL.getState());
+ _immutableMeasuringKeys =
+ DataBlockBuilder.buildFromRows(Collections.singletonList(_extremumMeasuringKeys), _measuringSchema);
+ _immutableProjectionVals = DataBlockBuilder.buildFromRows(_extremumProjectionValues, _projectionSchema);
+ }
+ dataOutputStream.write(_immutableMeasuringKeys.toBytes());
+ dataOutputStream.write(_immutableProjectionVals.toBytes());
+ return byteArrayOutputStream.toByteArray();
+ }
+
+ /**
+ * Used during segment processing
+ * Compare the current key with the new key, and return the comparison result.
+ * > 0: the key is replaced because the new key is the new extremum
+ * = 0: new key is the same as the current extremum
+ * < 0: current key is still the extremum
+ */
+ public int compareAndSetKey(List<ArgMinMaxMeasuringValSetWrapper> argMinMaxWrapperValSets, int offset,
+ boolean isMax) {
+ Preconditions.checkState(_mutable, "Cannot compare and set key after the object is serialized");
+ if (!_isNull) {
+ for (int i = 0; i < _sizeOfExtremumMeasuringKeys; i++) {
+ ArgMinMaxMeasuringValSetWrapper argMinMaxWrapperValSet = argMinMaxWrapperValSets.get(i);
+ int result = argMinMaxWrapperValSet.compare(offset, _extremumMeasuringKeys[i]);
+ if (result != 0) {
+ if (isMax ? result < 0 : result > 0) {
+ for (int j = 0; j < _sizeOfExtremumMeasuringKeys; j++) {
+ _extremumMeasuringKeys[j] = argMinMaxWrapperValSets.get(j).getComparable(offset);
+ }
+ return 1;
+ }
+ return -1;
+ }
+ }
+ } else {
+ _isNull = false;
+ _extremumMeasuringKeys = new Comparable[_sizeOfExtremumMeasuringKeys];
+ for (int i = 0; i < _sizeOfExtremumMeasuringKeys; i++) {
+ _extremumMeasuringKeys[i] = argMinMaxWrapperValSets.get(i).getComparable(offset);
+ }
+ }
+ return 0;
+ }
+
+ /**
+ * Used during segment processing with compareAndSetKey
+ * Set the vals to the new vals if the key is replaced.
+ */
+ public void setToNewVal(List<ArgMinMaxProjectionValSetWrapper> argMinMaxProjectionValSetWrappers, int offset) {
+ _extremumProjectionValues.clear();
+ addVal(argMinMaxProjectionValSetWrappers, offset);
+ }
+
+ /**
+ * Used during segment processing with compareAndSetKey
+ * Add the vals to the list of vals if the key is the same.
+ */
+ public void addVal(List<ArgMinMaxProjectionValSetWrapper> argMinMaxProjectionValSetWrappers, int offset) {
+ Object[] val = new Object[_projectionSchema.size()];
+ for (int i = 0; i < _projectionSchema.size(); i++) {
+ val[i] = argMinMaxProjectionValSetWrappers.get(i).getValue(offset);
+ }
+ _extremumProjectionValues.add(val);
+ }
+
+ public Comparable[] getExtremumKey() {
+ if (_mutable) {
+ return _extremumMeasuringKeys;
+ } else {
+ Comparable[] extremumKeys = new Comparable[_sizeOfExtremumMeasuringKeys];
+ for (int i = 0; i < _sizeOfExtremumMeasuringKeys; i++) {
+ switch (_measuringSchema.getColumnDataType(i)) {
+ case INT:
+ case BOOLEAN:
+ extremumKeys[i] = _immutableMeasuringKeys.getInt(0, i);
+ break;
+ case LONG:
+ case TIMESTAMP:
+ extremumKeys[i] = _immutableMeasuringKeys.getLong(0, i);
+ break;
+ case FLOAT:
+ extremumKeys[i] = _immutableMeasuringKeys.getFloat(0, i);
+ break;
+ case DOUBLE:
+ extremumKeys[i] = _immutableMeasuringKeys.getDouble(0, i);
+ break;
+ case STRING:
+ extremumKeys[i] = _immutableMeasuringKeys.getString(0, i);
+ break;
+ case BIG_DECIMAL:
+ extremumKeys[i] = _immutableMeasuringKeys.getBigDecimal(0, i);
+ break;
+ default:
+ throw new IllegalStateException("Unsupported data type: " + _measuringSchema.getColumnDataType(i));
+ }
+ }
+ return extremumKeys;
+ }
+ }
+
+ /**
+ * Get the field from a projection column
+ */
+ @Override
+ public Object getField(int rowId, int colId) {
+ if (_mutable) {
+ return _extremumProjectionValues.get(rowId)[colId];
+ } else {
+ switch (_projectionSchema.getColumnDataType(colId)) {
+ case BOOLEAN:
+ case INT:
+ return _immutableProjectionVals.getInt(rowId, colId);
+ case TIMESTAMP:
+ case LONG:
+ return _immutableProjectionVals.getLong(rowId, colId);
+ case FLOAT:
+ return _immutableProjectionVals.getFloat(rowId, colId);
+ case DOUBLE:
+ return _immutableProjectionVals.getDouble(rowId, colId);
+ case JSON:
+ case STRING:
+ return _immutableProjectionVals.getString(rowId, colId);
+ case BYTES:
+ return _immutableProjectionVals.getBytes(rowId, colId);
+ case BIG_DECIMAL:
+ return _immutableProjectionVals.getBigDecimal(rowId, colId);
+ case BOOLEAN_ARRAY:
+ case INT_ARRAY:
+ return _immutableProjectionVals.getIntArray(rowId, colId);
+ case TIMESTAMP_ARRAY:
+ case LONG_ARRAY:
+ return _immutableProjectionVals.getLongArray(rowId, colId);
+ case FLOAT_ARRAY:
+ return _immutableProjectionVals.getFloatArray(rowId, colId);
+ case DOUBLE_ARRAY:
+ return _immutableProjectionVals.getDoubleArray(rowId, colId);
+ case STRING_ARRAY:
+ case BYTES_ARRAY:
+ return _immutableProjectionVals.getStringArray(rowId, colId);
+ default:
+ throw new IllegalStateException("Unsupported data type: " + _projectionSchema.getColumnDataType(colId));
+ }
+ }
+ }
+
+ /**
+ * Merge two ArgMinMaxObjects
+ */
+ public ArgMinMaxObject merge(ArgMinMaxObject other, boolean isMax) {
+ if (_isNull && other._isNull) {
+ return this;
+ } else if (_isNull) {
+ return other;
+ } else if (other._isNull) {
+ return this;
+ } else {
+ int result;
+ Comparable[] key = getExtremumKey();
+ Comparable[] otherKey = other.getExtremumKey();
+ for (int i = 0; i < _sizeOfExtremumMeasuringKeys; i++) {
+ result = key[i].compareTo(otherKey[i]);
+ if (result != 0) {
+ // If the keys are not equal, return the object with the extremum key
+ if (isMax) {
+ return result > 0 ? this : other;
+ } else {
+ return result < 0 ? this : other;
+ }
+ }
+ }
+ // If the keys are equal, add the values of the other object to this object
+ if (!_mutable) {
+ // If the result is immutable, we need to copy the values from the serialized result to the mutable result
+ _mutable = true;
+ for (int i = 0; i < getNumberOfRows(); i++) {
+ Object[] val = new Object[_sizeOfExtremumProjectionVals];
+ for (int j = 0; j < _sizeOfExtremumProjectionVals; j++) {
+ val[j] = getField(i, j);
+ }
+ _extremumProjectionValues.add(val);
+ }
+ }
+ for (int i = 0; i < other.getNumberOfRows(); i++) {
+ Object[] val = new Object[_sizeOfExtremumProjectionVals];
+ for (int j = 0; j < _sizeOfExtremumProjectionVals; j++) {
+ val[j] = other.getField(i, j);
+ }
+ _extremumProjectionValues.add(val);
+ }
+ return this;
+ }
+ }
+
+ /**
+ * get the number of rows in the projection data
+ */
+ @Override
+ public int getNumberOfRows() {
+ if (_mutable) {
+ return _extremumProjectionValues.size();
+ } else {
+ return _immutableProjectionVals.getNumberOfRows();
+ }
+ }
+
+ /**
+ * return the schema of the projection data
+ */
+ @Override
+ public DataSchema getSchema() {
+ // the final parent aggregation result only cares about the projection columns
+ return _projectionSchema;
+ }
+
+ @Override
+ public int compareTo(ParentAggregationFunctionResultObject o) {
+ return this.getNumberOfRows() - o.getNumberOfRows();
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxProjectionValSetWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxProjectionValSetWrapper.java
new file mode 100644
index 0000000000..c8676ba716
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxProjectionValSetWrapper.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.utils.argminmax;
+
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+
+
+/**
+ * Wrapper class for projection block value set for argmin/max aggregation function.
+ * Used to get the value from val set of different data types.
+ */
+public class ArgMinMaxProjectionValSetWrapper extends ArgMinMaxWrapperValSet {
+
+  /**
+   * Creates the wrapper and immediately loads the values of the given block.
+   *
+   * @param isSingleValue whether the wrapped column is single-valued
+   * @param dataType the data type of the wrapped column
+   * @param blockValSet the block to read the values from
+   */
+  public ArgMinMaxProjectionValSetWrapper(boolean isSingleValue, DataSchema.ColumnDataType dataType,
+      BlockValSet blockValSet) {
+    super(dataType, isSingleValue);
+    setNewBlock(blockValSet);
+  }
+
+  /**
+   * Returns the value at index {@code i} of the current block; primitive values are returned boxed.
+   * For multi-value columns an empty array is returned as {@code null}.
+   *
+   * @param i index of the value inside the current block
+   * @throws IllegalStateException if the data type of the column is not supported
+   */
+  public Object getValue(int i) {
+    switch (_dataType) {
+      case INT:
+      case BOOLEAN:
+        return _intValues[i];
+      case LONG:
+      case TIMESTAMP:
+        return _longValues[i];
+      case FLOAT:
+        return _floatValues[i];
+      case DOUBLE:
+        return _doublesValues[i];
+      case STRING:
+      case BIG_DECIMAL:
+      case BYTES:
+      case JSON:
+        return _objectsValues[i];
+      case INT_ARRAY:
+      // setNewBlock() loads BOOLEAN_ARRAY into _intValuesMV, so it has to be readable here as well
+      case BOOLEAN_ARRAY:
+        return _intValuesMV[i].length == 0 ? null : _intValuesMV[i];
+      case LONG_ARRAY:
+      case TIMESTAMP_ARRAY:
+        return _longValuesMV[i].length == 0 ? null : _longValuesMV[i];
+      case FLOAT_ARRAY:
+        return _floatValuesMV[i].length == 0 ? null : _floatValuesMV[i];
+      case DOUBLE_ARRAY:
+        return _doublesValuesMV[i].length == 0 ? null : _doublesValuesMV[i];
+      case STRING_ARRAY:
+      case BYTES_ARRAY:
+        return _objectsValuesMV[i].length == 0 ? null : _objectsValuesMV[i];
+      default:
+        throw new IllegalStateException("Unsupported data type: " + _dataType);
+    }
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxWrapperValSet.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxWrapperValSet.java
new file mode 100644
index 0000000000..5aa18eae5c
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/utils/argminmax/ArgMinMaxWrapperValSet.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.utils.argminmax;
+
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+
+
+/**
+ * Base class of the argmin/max value set wrappers.
+ * It keeps one typed array reference per supported storage type and fills the one matching the column data type
+ * from a {@link BlockValSet}, which lets argmin/max be implemented independently of the column type.
+ */
+public class ArgMinMaxWrapperValSet {
+  protected final DataSchema.ColumnDataType _dataType;
+  boolean _isSingleValue;
+
+  // values of the current block for single-value columns; only the array matching _dataType is set
+  int[] _intValues;
+  long[] _longValues;
+  float[] _floatValues;
+  double[] _doublesValues;
+  Object[] _objectsValues;
+
+  // values of the current block for multi-value columns; only the array matching _dataType is set
+  int[][] _intValuesMV;
+  long[][] _longValuesMV;
+  float[][] _floatValuesMV;
+  double[][] _doublesValuesMV;
+  Object[][] _objectsValuesMV;
+
+  public ArgMinMaxWrapperValSet(DataSchema.ColumnDataType dataType, boolean isSingleValue) {
+    _dataType = dataType;
+    _isSingleValue = isSingleValue;
+  }
+
+  /**
+   * Points the wrapper at the values of a new block.
+   *
+   * @param blockValSet the block to read the values from
+   * @throws IllegalStateException if the data type is not supported for the single/multi-value kind of the column
+   */
+  public void setNewBlock(BlockValSet blockValSet) {
+    if (!_isSingleValue) {
+      loadMultiValues(blockValSet);
+      return;
+    }
+    loadSingleValues(blockValSet);
+  }
+
+  /**
+   * Reads the values of a single-value column into the array matching the data type
+   */
+  private void loadSingleValues(BlockValSet valSet) {
+    switch (_dataType) {
+      case INT:
+      case BOOLEAN:
+        _intValues = valSet.getIntValuesSV();
+        return;
+      case LONG:
+      case TIMESTAMP:
+        _longValues = valSet.getLongValuesSV();
+        return;
+      case FLOAT:
+        _floatValues = valSet.getFloatValuesSV();
+        return;
+      case DOUBLE:
+        _doublesValues = valSet.getDoubleValuesSV();
+        return;
+      case STRING:
+      case JSON:
+        _objectsValues = valSet.getStringValuesSV();
+        return;
+      case BIG_DECIMAL:
+        _objectsValues = valSet.getBigDecimalValuesSV();
+        return;
+      case BYTES:
+        _objectsValues = valSet.getBytesValuesSV();
+        return;
+      default:
+        throw new IllegalStateException("Unsupported data type: " + _dataType);
+    }
+  }
+
+  /**
+   * Reads the values of a multi-value column into the array matching the data type
+   */
+  private void loadMultiValues(BlockValSet valSet) {
+    switch (_dataType) {
+      case INT_ARRAY:
+      case BOOLEAN_ARRAY:
+        _intValuesMV = valSet.getIntValuesMV();
+        return;
+      case LONG_ARRAY:
+      case TIMESTAMP_ARRAY:
+        _longValuesMV = valSet.getLongValuesMV();
+        return;
+      case FLOAT_ARRAY:
+        _floatValuesMV = valSet.getFloatValuesMV();
+        return;
+      case DOUBLE_ARRAY:
+        _doublesValuesMV = valSet.getDoubleValuesMV();
+        return;
+      case STRING_ARRAY:
+        _objectsValuesMV = valSet.getStringValuesMV();
+        return;
+      case BYTES_ARRAY:
+        _objectsValuesMV = valSet.getBytesValuesMV();
+        return;
+      default:
+        throw new IllegalStateException("Unsupported data type: " + _dataType);
+    }
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 811c9738ba..89e28d1924 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -34,6 +34,8 @@ import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.query.utils.rewriter.ResultRewriteUtils;
+import org.apache.pinot.core.query.utils.rewriter.RewriterResult;
import org.apache.pinot.core.transport.ServerRoutingInstance;
import org.apache.pinot.spi.trace.Tracing;
import org.roaringbitmap.RoaringBitmap;
@@ -142,12 +144,21 @@ public class AggregationDataTableReducer implements DataTableReducer {
new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema());
DataSchema dataSchema = postAggregationHandler.getResultDataSchema();
Object[] row = postAggregationHandler.getResult(finalResults);
+
+ RewriterResult resultRewriterResult =
+ ResultRewriteUtils.rewriteResult(dataSchema, Collections.singletonList(row));
+ dataSchema = resultRewriterResult.getDataSchema();
+ List<Object[]> rows = resultRewriterResult.getRows();
+
ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
int numColumns = columnDataTypes.length;
- for (int i = 0; i < numColumns; i++) {
- row[i] = columnDataTypes[i].format(row[i]);
+ for (Object[] rewrittenRow : rows) {
+ for (int j = 0; j < numColumns; j++) {
+ rewrittenRow[j] = columnDataTypes[j].format(rewrittenRow[j]);
+ }
}
- return new ResultTable(dataSchema, Collections.singletonList(row));
+
+ return new ResultTable(dataSchema, rows);
}
/**
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index e7dae1c16f..654e6232a2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -54,6 +54,8 @@ import org.apache.pinot.core.operator.combine.GroupByCombineOperator;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.request.context.QueryContext;
+import org.apache.pinot.core.query.utils.rewriter.ResultRewriteUtils;
+import org.apache.pinot.core.query.utils.rewriter.RewriterResult;
import org.apache.pinot.core.transport.ServerRoutingInstance;
import org.apache.pinot.core.util.GroupByUtils;
import org.apache.pinot.core.util.trace.TraceRunnable;
@@ -103,7 +105,8 @@ public class GroupByDataTableReducer implements DataTableReducer {
PostAggregationHandler postAggregationHandler =
new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema(dataSchema));
DataSchema resultDataSchema = postAggregationHandler.getResultDataSchema();
- brokerResponse.setResultTable(new ResultTable(resultDataSchema, Collections.emptyList()));
+ RewriterResult rewriterResult = ResultRewriteUtils.rewriteResult(resultDataSchema, Collections.emptyList());
+ brokerResponse.setResultTable(new ResultTable(rewriterResult.getDataSchema(), rewriterResult.getRows()));
return;
}
@@ -206,6 +209,11 @@ public class GroupByDataTableReducer implements DataTableReducer {
// Calculate final result rows after post aggregation
List<Object[]> resultRows = calculateFinalResultRows(postAggregationHandler, rows);
+ RewriterResult resultRewriterResult =
+ ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
+ resultRows = resultRewriterResult.getRows();
+ resultDataSchema = resultRewriterResult.getDataSchema();
+
brokerResponseNative.setResultTable(new ResultTable(resultDataSchema, resultRows));
}
@@ -442,6 +450,11 @@ public class GroupByDataTableReducer implements DataTableReducer {
// Calculate final result rows after post aggregation
List<Object[]> resultRows = calculateFinalResultRows(postAggregationHandler, rows);
+ RewriterResult resultRewriterResult =
+ ResultRewriteUtils.rewriteResult(resultDataSchema, resultRows);
+ resultRows = resultRewriterResult.getRows();
+ resultDataSchema = resultRewriterResult.getDataSchema();
+
brokerResponseNative.setResultTable(new ResultTable(resultDataSchema, resultRows));
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ParentAggregationResultRewriter.java b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ParentAggregationResultRewriter.java
new file mode 100644
index 0000000000..d650303f8c
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ParentAggregationResultRewriter.java
@@ -0,0 +1,222 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.utils.rewriter;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.utils.ParentAggregationFunctionResultObject;
+import org.apache.pinot.spi.utils.CommonConstants;
+
+
+/**
+ * Used in aggregation and group-by queries with aggregation functions.
+ * Use the result of parent aggregation functions to populate the result of child aggregation functions.
+ * This implementation is based on the column names of the result schema.
+ * The result column name of a parent aggregation function has the following format:
+ * CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + aggregationFunctionType + FunctionID
+ * The result column name of corresponding child aggregation function has the following format:
+ * CHILD_AGGREGATION_NAME_PREFIX + aggregationFunctionType + operands + CHILD_AGGREGATION_SEPERATOR
+ * + aggregationFunctionType + parent FunctionID + CHILD_KEY_SEPERATOR + column key in parent function
+ * This approach will not work with `AS` clauses as they alter the column names.
+ * TODO: Add support for `AS` clauses.
+ */
+public class ParentAggregationResultRewriter implements ResultRewriter {
+  public ParentAggregationResultRewriter() {
+  }
+
+  /**
+   * Builds the lookup from child function key to its parent result, using the parent result objects found in
+   * the given row. Every column of a parent's nested schema yields one entry.
+   */
+  private static Map<String, ChildFunctionMapping> createChildFunctionMapping(DataSchema schema, Object[] row) {
+    Map<String, ChildFunctionMapping> childFunctionMapping = new HashMap<>();
+    for (int i = 0; i < schema.size(); i++) {
+      String columnName = schema.getColumnName(i);
+      if (columnName.startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) {
+        ParentAggregationFunctionResultObject parent = (ParentAggregationFunctionResultObject) row[i];
+
+        DataSchema nestedSchema = parent.getSchema();
+        for (int j = 0; j < nestedSchema.size(); j++) {
+          String childColumnKey = nestedSchema.getColumnName(j);
+          String originalChildFunctionKey =
+              columnName.substring(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX.length())
+                  + CommonConstants.RewriterConstants.CHILD_KEY_SEPERATOR + childColumnKey;
+          // aggregationFunctionType + childFunctionID + CHILD_KEY_SEPERATOR + childFunctionKeyInParent
+          childFunctionMapping.put(originalChildFunctionKey, new ChildFunctionMapping(parent, j, i));
+        }
+      }
+    }
+    return childFunctionMapping;
+  }
+
+  /**
+   * Removes the parent aggregation columns from the result and fills the child aggregation columns from the
+   * parent result objects. A row whose parent results hold several rows is expanded into that many rows.
+   * The input is returned unchanged if the schema contains no parent aggregation column.
+   *
+   * @param dataSchema the schema of the result before rewriting
+   * @param rows the result rows before rewriting
+   * @return the rewritten schema and rows
+   * @throws IllegalStateException if a child aggregation column does not follow the naming format described on
+   *     this class or if its parent aggregation result cannot be found
+   */
+  public RewriterResult rewrite(DataSchema dataSchema, List<Object[]> rows) {
+
+    int numParentAggregationFunctions = 0;
+    // Count the number of parent aggregation functions
+    for (int i = 0; i < dataSchema.size(); i++) {
+      if (dataSchema.getColumnName(i).startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) {
+        numParentAggregationFunctions++;
+      }
+    }
+
+    if (numParentAggregationFunctions == 0) {
+      // no change to the result
+      return new RewriterResult(dataSchema, rows);
+    }
+
+    Map<String, ChildFunctionMapping> childFunctionMapping = null;
+    if (!rows.isEmpty()) {
+      // Create a mapping from the child aggregation function name to the child aggregation function
+      childFunctionMapping = createChildFunctionMapping(dataSchema, rows.get(0));
+    }
+
+    int numNewColumns = dataSchema.size() - numParentAggregationFunctions;
+    String[] newColumnNames = new String[numNewColumns];
+    DataSchema.ColumnDataType[] newColumnDataTypes = new DataSchema.ColumnDataType[numNewColumns];
+
+    // Create a mapping from the function offset in the final aggregation result
+    // to its own/parent function offset in the original aggregation result
+    Map<Integer, Integer> aggregationFunctionIndexMapping = new HashMap<>();
+    // Create a set of the result indices of the child aggregation functions
+    Set<Integer> childAggregationFunctionIndices = new HashSet<>();
+    // Create a mapping from the result aggregation function index to the nested index of the
+    // child aggregation function in the parent aggregation function
+    Map<Integer, Integer> childAggregationFunctionNestedIndexMapping = new HashMap<>();
+    // Create a set of the result indices of the parent aggregation functions
+    Set<Integer> parentAggregationFunctionIndices = new HashSet<>();
+
+    for (int i = 0, j = 0; i < dataSchema.size(); i++) {
+      String columnName = dataSchema.getColumnName(i);
+      // Skip the parent aggregation functions
+      if (columnName.startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) {
+        parentAggregationFunctionIndices.add(i);
+        continue;
+      }
+
+      // for child aggregation functions and regular columns in the result
+      // create a new schema and populate the new column names and data types
+      // also populate the offset mappings used to rewrite the result
+      if (columnName.startsWith(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX)) {
+        // This is a child column of a parent aggregation function
+        String childAggregationFunctionNameWithKey =
+            columnName.substring(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX.length());
+        String[] s = childAggregationFunctionNameWithKey
+            .split(CommonConstants.RewriterConstants.CHILD_AGGREGATION_SEPERATOR);
+        newColumnNames[j] = s[0];
+
+        if (childFunctionMapping == null) {
+          // there is no row, hence no parent result to take the type from
+          newColumnDataTypes[j] = DataSchema.ColumnDataType.STRING;
+        } else {
+          // fail with a descriptive error instead of an ArrayIndexOutOfBounds/NullPointerException
+          if (s.length < 2) {
+            throw new IllegalStateException("Malformed child aggregation column name: " + columnName);
+          }
+          ChildFunctionMapping childFunction = childFunctionMapping.get(s[1]);
+          if (childFunction == null) {
+            throw new IllegalStateException(
+                "Cannot find the parent aggregation result for child aggregation column: " + columnName);
+          }
+          newColumnDataTypes[j] = childFunction.getParent().getSchema()
+              .getColumnDataType(childFunction.getNestedOffset());
+
+          childAggregationFunctionNestedIndexMapping.put(j, childFunction.getNestedOffset());
+          childAggregationFunctionIndices.add(j);
+          aggregationFunctionIndexMapping.put(j, childFunction.getOffset());
+        }
+      } else {
+        // This is a regular column
+        newColumnNames[j] = columnName;
+        newColumnDataTypes[j] = dataSchema.getColumnDataType(i);
+
+        aggregationFunctionIndexMapping.put(j, i);
+      }
+      j++;
+    }
+
+    DataSchema newDataSchema = new DataSchema(newColumnNames, newColumnDataTypes);
+    List<Object[]> newRows = new ArrayList<>();
+
+    for (Object[] row : rows) {
+      // the row is expanded to the largest number of rows held by any of its parent results, at least one
+      int maxRows = 0;
+      for (int parentIndex : parentAggregationFunctionIndices) {
+        ParentAggregationFunctionResultObject parentAggregationFunctionResultObject =
+            (ParentAggregationFunctionResultObject) row[parentIndex];
+        maxRows = Math.max(maxRows, parentAggregationFunctionResultObject.getNumberOfRows());
+      }
+      if (maxRows == 0) {
+        maxRows = 1;
+      }
+
+      for (int rowIter = 0; rowIter < maxRows; rowIter++) {
+        Object[] newRow = new Object[newDataSchema.size()];
+        for (int fieldIter = 0; fieldIter < newDataSchema.size(); fieldIter++) {
+          // If the field is a child aggregation function, extract the value from the parent result
+          if (childAggregationFunctionIndices.contains(fieldIter)) {
+            int offset = aggregationFunctionIndexMapping.get(fieldIter);
+            int nestedOffset = childAggregationFunctionNestedIndexMapping.get(fieldIter);
+            ParentAggregationFunctionResultObject parentAggregationFunctionResultObject =
+                (ParentAggregationFunctionResultObject) row[offset];
+            // If the parent result has more rows than the current row, extract the value from the row
+            if (rowIter < parentAggregationFunctionResultObject.getNumberOfRows()) {
+              newRow[fieldIter] = parentAggregationFunctionResultObject.getField(rowIter, nestedOffset);
+            } else {
+              newRow[fieldIter] = null;
+            }
+          } else {
+            // If the field is a regular column, the value of the original row is repeated in every expanded row
+            newRow[fieldIter] = row[aggregationFunctionIndexMapping.get(fieldIter)];
+          }
+        }
+        newRows.add(newRow);
+      }
+    }
+    return new RewriterResult(newDataSchema, newRows);
+  }
+
+  /**
+   * Mapping from child function key to
+   * 1. the parent result object,
+   * 2. offset of the parent result column in original result row,
+   * 3. the nested offset of the child function result in the parent data block
+   *
+   * For example, for a list of aggregation functions result:
+   * 0 1 2 3
+   * | | | |
+   * "child_argmin(a, b, x) ,child_argmin(a, b, y), child_argmin(a, b, z), parent_argmin(a, b, x, y, z)"
+   * | | |
+   * 0 1 2
+   * offset of the parent of child_argmin(a, b, y) is 3
+   * nested offset of child_argmin(a, b, y) is 1
+   */
+  private static class ChildFunctionMapping {
+    private final ParentAggregationFunctionResultObject _parent;
+    private final int _nestedOffset;
+    private final int _offset;
+
+    public ChildFunctionMapping(ParentAggregationFunctionResultObject parent, int nestedOffset, int offset) {
+      _parent = parent;
+      _nestedOffset = nestedOffset;
+      _offset = offset;
+    }
+
+    public int getOffset() {
+      return _offset;
+    }
+
+    public ParentAggregationFunctionResultObject getParent() {
+      return _parent;
+    }
+
+    public int getNestedOffset() {
+      return _nestedOffset;
+    }
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriteUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriteUtils.java
new file mode 100644
index 0000000000..3ae7f594e6
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriteUtils.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.utils.rewriter;
+
+import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+
+
+public class ResultRewriteUtils {
+
+ private ResultRewriteUtils() {
+ }
+
+ public static RewriterResult rewriteResult(DataSchema dataSchema, List<Object[]> rows) {
+ for (ResultRewriter resultRewriter : ResultRewriterFactory.getResultRewriter()) {
+ RewriterResult result = resultRewriter.rewrite(dataSchema, rows);
+ dataSchema = result.getDataSchema();
+ rows = result.getRows();
+ }
+ return new RewriterResult(dataSchema, rows);
+ }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriter.java b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriter.java
new file mode 100644
index 0000000000..80975e8f1c
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriter.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.utils.rewriter;
+
+import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+
+
+/**
+ * Interface for rewriting the result of a query.
+ *
+ * <p>An implementation receives the schema and rows of a result and hands back a {@link RewriterResult} carrying
+ * the schema and rows to use from then on.
+ */
+public interface ResultRewriter {
+ /**
+ * Rewrites one query result.
+ *
+ * @param dataSchema schema describing the columns of {@code rows}
+ * @param rows result rows, one {@code Object[]} per row
+ * @return the schema and rows after rewriting
+ */
+ RewriterResult rewrite(DataSchema dataSchema, List<Object[]> rows);
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriterFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriterFactory.java
new file mode 100644
index 0000000000..9d3d0e45da
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ResultRewriterFactory.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.utils.rewriter;
+
+import com.google.common.collect.ImmutableList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Holds the process-wide list of {@link ResultRewriter} instances and (re)builds it from class names.
+ *
+ * <p>No rewriter is loaded until {@link #init(String)} is called with at least one class name.
+ */
+public class ResultRewriterFactory {
+
+  private ResultRewriterFactory() {
+  }
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ResultRewriterFactory.class);
+  // left blank intentionally to not load any result rewriter by default
+  static final List<String> DEFAULT_RESULT_REWRITERS_CLASS_NAMES = ImmutableList.of();
+
+  static AtomicReference<List<ResultRewriter>> _resultRewriters
+      = new AtomicReference<>(getResultRewriter(DEFAULT_RESULT_REWRITERS_CLASS_NAMES));
+
+  /**
+   * Replaces the registered rewriters with instances of the given classes.
+   *
+   * @param resultRewritersClassNamesStr comma-separated fully-qualified class names; {@code null} selects
+   *        {@link #DEFAULT_RESULT_REWRITERS_CLASS_NAMES}. Whitespace around a name is ignored and blank entries
+   *        are skipped. A class that cannot be loaded or instantiated is logged and left out.
+   */
+  public static void init(String resultRewritersClassNamesStr) {
+    List<String> resultRewritersClassNames =
+        (resultRewritersClassNamesStr != null) ? Arrays.asList(resultRewritersClassNamesStr.split(","))
+            : DEFAULT_RESULT_REWRITERS_CLASS_NAMES;
+    _resultRewriters.set(getResultRewriter(resultRewritersClassNames));
+  }
+
+  /**
+   * @return the currently registered rewriters, in the order they were configured
+   */
+  public static List<ResultRewriter> getResultRewriter() {
+    return _resultRewriters.get();
+  }
+
+  /**
+   * Instantiates every named rewriter, logging and skipping the ones that fail.
+   */
+  private static List<ResultRewriter> getResultRewriter(List<String> resultRewriterClasses) {
+    final ImmutableList.Builder<ResultRewriter> builder = ImmutableList.builder();
+    for (String rawClassName : resultRewriterClasses) {
+      // Tolerate "A, B" style lists and the single empty entry produced by splitting an empty string
+      String resultRewriterClassName = rawClassName.trim();
+      if (resultRewriterClassName.isEmpty()) {
+        continue;
+      }
+      try {
+        builder.add(getResultRewriter(resultRewriterClassName));
+      } catch (Exception e) {
+        LOGGER.error("Failed to load resultRewriter: {}", resultRewriterClassName, e);
+      }
+    }
+    return builder.build();
+  }
+
+  /**
+   * Loads the class and creates an instance through its no-argument constructor.
+   *
+   * @throws Exception if the class is missing, is not a {@link ResultRewriter}, or cannot be instantiated
+   */
+  private static ResultRewriter getResultRewriter(String resultRewriterClassName)
+      throws Exception {
+    // asSubclass checks the type up front instead of relying on an unchecked cast
+    final Class<? extends ResultRewriter> resultRewriterClass =
+        Class.forName(resultRewriterClassName).asSubclass(ResultRewriter.class);
+    // Ask for the no-arg constructor explicitly; the order of getDeclaredConstructors() is unspecified
+    return resultRewriterClass.getDeclaredConstructor().newInstance();
+  }
+}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/RewriterResult.java b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/RewriterResult.java
new file mode 100644
index 0000000000..f62b4a35f9
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/RewriterResult.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.utils.rewriter;
+
+import java.util.List;
+import org.apache.pinot.common.utils.DataSchema;
+
+
+/**
+ * Simple holder for what a result rewriter produces: a schema plus the rows that belong to it.
+ */
+public class RewriterResult {
+  DataSchema _dataSchema;
+  List<Object[]> _rows;
+
+  /**
+   * @param dataSchema schema describing {@code rows}
+   * @param rows result rows; the list is stored as given, not copied
+   */
+  public RewriterResult(DataSchema dataSchema, List<Object[]> rows) {
+    this._dataSchema = dataSchema;
+    this._rows = rows;
+  }
+
+  /** Returns the rows passed to the constructor. */
+  public List<Object[]> getRows() {
+    return this._rows;
+  }
+
+  /** Returns the schema passed to the constructor. */
+  public DataSchema getDataSchema() {
+    return this._dataSchema;
+  }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockTest.java b/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockTest.java
index d71260aa8b..593fd62a69 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/common/datablock/DataBlockTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.common.datablock;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.apache.pinot.common.datablock.ColumnarDataBlock;
@@ -95,4 +96,49 @@ public class DataBlockTest {
public Object[][] provideTestTypeNullPercentile() {
return new Object[][]{new Object[]{0}, new Object[]{10}, new Object[]{100}};
}
+
+ /**
+  * Pins the current limitation that a row holding a {@code byte[][]} value cannot be turned into a data block:
+  * the build is expected to end in an exception mentioning the unsupported type.
+  * TODO: bytes array serialization probably needs fixing.
+  */
+ @Test
+ void bytesArraySerDe() {
+   byte[][] bytesArrayValue = {{0xD, 0xA}, {0xD, 0xA}};
+   Object[] singleRow = {bytesArrayValue};
+   List<Object[]> inputRows = new ArrayList<>();
+   inputRows.add(singleRow);
+
+   String[] columnNames = {"byteArray"};
+   DataSchema.ColumnDataType[] columnTypes = {DataSchema.ColumnDataType.BYTES_ARRAY};
+   DataSchema schema = new DataSchema(columnNames, columnTypes);
+
+   try {
+     DataBlock built = DataBlockBuilder.buildFromRows(inputRows, schema);
+     Assert.assertNull(built);
+     // Reaching this point means no exception was thrown; AssertionError is not caught below
+     Assert.fail();
+   } catch (Exception e) {
+     String description = e.toString();
+     Assert.assertTrue(
+         description.contains("java.lang.IllegalArgumentException: Unsupported type of value: byte[][]"));
+   }
+ }
+
+ /**
+  * Round-trips a row holding an empty {@code int[]} through a data block and, if reading it back throws,
+  * checks that the failure is a NullPointerException. The test also passes when nothing is thrown.
+  * TODO: empty int array deserialization probably needs fixing.
+  */
+ @Test
+ void intArraySerDe()
+     throws IOException {
+   Object[] singleRow = {new int[0]};
+   List<Object[]> inputRows = new ArrayList<>();
+   inputRows.add(singleRow);
+
+   String[] columnNames = {"intArray"};
+   DataSchema.ColumnDataType[] columnTypes = {DataSchema.ColumnDataType.INT_ARRAY};
+   DataSchema schema = new DataSchema(columnNames, columnTypes);
+
+   DataBlock built = DataBlockBuilder.buildFromRows(inputRows, schema);
+   try {
+     byte[] serialized = built.toBytes();
+     DataBlockUtils.getDataBlock(ByteBuffer.wrap(serialized)).getIntArray(0, 0);
+   } catch (Exception e) {
+     String description = e.toString();
+     Assert.assertTrue(description.contains("java.lang.NullPointerException"));
+   }
+ }
}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/ArgMinMaxTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/ArgMinMaxTest.java
new file mode 100644
index 0000000000..94a63b6b24
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/ArgMinMaxTest.java
@@ -0,0 +1,644 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.response.BrokerResponse;
+import org.apache.pinot.common.response.broker.BrokerResponseNative;
+import org.apache.pinot.common.response.broker.ResultTable;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.utils.rewriter.ResultRewriterFactory;
+import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.apache.pinot.sql.parsers.rewriter.QueryRewriterFactory;
+import org.testng.Assert;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+import static org.testng.Assert.fail;
+
+
+/**
+ * Queries test for argMin/argMax functions.
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class ArgMinMaxTest extends BaseQueriesTest {
+ private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "ArgMinMaxTest");
+ private static final String RAW_TABLE_NAME = "testTable";
+ private static final String SEGMENT_NAME = "testSegment";
+
+ private static final int NUM_RECORDS = 2000;
+
+ private static final String INT_COLUMN = "intColumn";
+ private static final String LONG_COLUMN = "longColumn";
+ private static final String FLOAT_COLUMN = "floatColumn";
+ private static final String DOUBLE_COLUMN = "doubleColumn";
+ private static final String MV_DOUBLE_COLUMN = "mvDoubleColumn";
+ private static final String MV_INT_COLUMN = "mvIntColumn";
+ private static final String MV_BYTES_COLUMN = "mvBytesColumn";
+ private static final String MV_STRING_COLUMN = "mvStringColumn";
+ private static final String STRING_COLUMN = "stringColumn";
+ private static final String GROUP_BY_INT_COLUMN = "groupByIntColumn";
+ private static final String GROUP_BY_MV_INT_COLUMN = "groupByMVIntColumn";
+ private static final String GROUP_BY_INT_COLUMN2 = "groupByIntColumn2";
+ private static final String BIG_DECIMAL_COLUMN = "bigDecimalColumn";
+ private static final String TIMESTAMP_COLUMN = "timestampColumn";
+ private static final String BOOLEAN_COLUMN = "booleanColumn";
+ private static final String JSON_COLUMN = "jsonColumn";
+
+ private static final Schema SCHEMA = new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, DataType.INT)
+ .addSingleValueDimension(LONG_COLUMN, DataType.LONG).addSingleValueDimension(FLOAT_COLUMN, DataType.FLOAT)
+ .addSingleValueDimension(DOUBLE_COLUMN, DataType.DOUBLE).addMultiValueDimension(MV_INT_COLUMN, DataType.INT)
+ .addMultiValueDimension(MV_BYTES_COLUMN, DataType.BYTES)
+ .addMultiValueDimension(MV_STRING_COLUMN, DataType.STRING)
+ .addSingleValueDimension(STRING_COLUMN, DataType.STRING)
+ .addSingleValueDimension(GROUP_BY_INT_COLUMN, DataType.INT)
+ .addMultiValueDimension(GROUP_BY_MV_INT_COLUMN, DataType.INT)
+ .addSingleValueDimension(GROUP_BY_INT_COLUMN2, DataType.INT)
+ .addSingleValueDimension(BIG_DECIMAL_COLUMN, DataType.BIG_DECIMAL)
+ .addSingleValueDimension(TIMESTAMP_COLUMN, DataType.TIMESTAMP)
+ .addSingleValueDimension(BOOLEAN_COLUMN, DataType.BOOLEAN)
+ .addMultiValueDimension(MV_DOUBLE_COLUMN, DataType.DOUBLE)
+ .addSingleValueDimension(JSON_COLUMN, DataType.JSON)
+ .build();
+ private static final TableConfig TABLE_CONFIG =
+ new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+ private IndexSegment _indexSegment;
+ private List<IndexSegment> _indexSegments;
+
+ /**
+ * Returns the fixed filter clause {@code " WHERE intColumn >= 500"}.
+ * NOTE(review): none of the queries visible in this class append it; how the base class uses it should be
+ * confirmed there.
+ */
+ @Override
+ protected String getFilter() {
+ return " WHERE intColumn >= 500";
+ }
+
+ /**
+ * Returns the single segment loaded in {@link #setUp()}.
+ */
+ @Override
+ protected IndexSegment getIndexSegment() {
+ return _indexSegment;
+ }
+
+ /**
+ * Returns the segment list built in {@link #setUp()}, which contains the same loaded segment twice.
+ */
+ @Override
+ protected List<IndexSegment> getIndexSegments() {
+ return _indexSegments;
+ }
+
+ /**
+ * Generates {@code NUM_RECORDS} rows, builds and loads one segment from them, and registers the rewriters the
+ * arg_min/arg_max queries need. The expected values asserted by the tests are derived from the formulas below.
+ */
+ @BeforeClass
+ public void setUp()
+ throws Exception {
+ FileUtils.deleteDirectory(INDEX_DIR);
+
+ List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+ String[] stringSVVals = new String[]{"a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a11", "a22"};
+ // j is doubled whenever i catches up with it, so when it is stored below it is the smallest power of two
+ // strictly greater than i
+ int j = 1;
+ for (int i = 0; i < NUM_RECORDS; i++) {
+ GenericRow record = new GenericRow();
+ record.putValue(INT_COLUMN, i);
+ // longColumn = i - NUM_RECORDS / 2
+ record.putValue(LONG_COLUMN, (long) i - NUM_RECORDS / 2);
+ record.putValue(FLOAT_COLUMN, (float) i * 0.5);
+ record.putValue(DOUBLE_COLUMN, (double) i);
+ record.putValue(MV_INT_COLUMN, Arrays.asList(i, i + 1, i + 2));
+ record.putValue(MV_BYTES_COLUMN, Arrays.asList(String.valueOf(i).getBytes(), String.valueOf(i + 1).getBytes(),
+ String.valueOf(i + 2).getBytes()));
+ // String concatenation, not addition: row i holds "a<i>", "a<i>1", "a<i>2"
+ record.putValue(MV_STRING_COLUMN, Arrays.asList("a" + i, "a" + i + 1, "a" + i + 2));
+ // The first 20 rows cycle through stringSVVals; every later row gets "a33"
+ if (i < 20) {
+ record.putValue(STRING_COLUMN, stringSVVals[i % stringSVVals.length]);
+ } else {
+ record.putValue(STRING_COLUMN, "a33");
+ }
+ record.putValue(GROUP_BY_INT_COLUMN, i % 5);
+ record.putValue(GROUP_BY_MV_INT_COLUMN, Arrays.asList(i % 10, (i + 1) % 10));
+ if (i == j) {
+ j *= 2;
+ }
+ record.putValue(GROUP_BY_INT_COLUMN2, j);
+ // -i^2 + 1200i: zero at i = 0 and i = 1200, largest (360000) at i = 600
+ record.putValue(BIG_DECIMAL_COLUMN, new BigDecimal(-i * i + 1200 * i));
+ // Decreases as i grows, so the smallest timestamp belongs to the last row
+ record.putValue(TIMESTAMP_COLUMN, 1683138373879L - i);
+ record.putValue(BOOLEAN_COLUMN, i % 2);
+ record.putValue(MV_DOUBLE_COLUMN, Arrays.asList((double) i, (double) i * i, (double) i * i * i));
+ record.putValue(JSON_COLUMN, "{\"name\":\"John\", \"age\":" + i + ", \"car\":null}");
+ records.add(record);
+ }
+
+ SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+ segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+ segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+ segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+ SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+ driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+ driver.build();
+
+ ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+ _indexSegment = immutableSegment;
+ // The same segment is listed twice
+ _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+
+ // Keep the default query rewriters and append the ArgMinMax one; the result rewriter has no defaults
+ QueryRewriterFactory.init(String.join(",", QueryRewriterFactory.DEFAULT_QUERY_REWRITERS_CLASS_NAMES)
+ + ",org.apache.pinot.sql.parsers.rewriter.ArgMinMaxRewriter");
+ ResultRewriterFactory
+ .init("org.apache.pinot.core.query.utils.rewriter.ParentAggregationResultRewriter");
+ }
+
+ @Test
+ public void invalidParamTest() {
+ String query = "SELECT arg_max(intColumn) FROM testTable";
+ try {
+ getBrokerResponse(query);
+ fail("Should have failed for invalid params");
+ } catch (Exception e) {
+ Assert.assertTrue(e.getMessage().contains("Invalid number of arguments for argmax"));
+ }
+
+ query = "SELECT arg_max() FROM testTable";
+ try {
+ getBrokerResponse(query);
+ fail("Should have failed for invalid params");
+ } catch (Exception e) {
+ Assert.assertTrue(e.getMessage().contains("Invalid number of arguments for argmax"));
+ }
+
+ query = "SELECT arg_max(mvDoubleColumn, mvDoubleColumn) FROM testTable";
+ BrokerResponse brokerResponse = getBrokerResponse(query);
+ Assert.assertTrue(brokerResponse.getProcessingExceptions().get(0).getMessage().contains(
+ "java.lang.IllegalStateException: ArgMinMax only supports single-valued measuring columns"
+ ));
+
+ query = "SELECT arg_max(jsonColumn, mvDoubleColumn) FROM testTable";
+ brokerResponse = getBrokerResponse(query);
+ Assert.assertTrue(brokerResponse.getProcessingExceptions().get(0).getMessage().contains(
+ "Cannot compute ArgMinMax measuring on non-comparable type: JSON"
+ ));
+ }
+
+ /**
+ * Exercises arg_min/arg_max without GROUP BY: projection of every supported column type, ties that yield
+ * several result rows, mixing with a regular aggregation, transform expressions and CASE expressions as
+ * arguments, and the known failure for multi-value BYTES projection.
+ */
+ @Test
+ public void testAggregationInterSegment() {
+ // Simple inter segment aggregation test
+ String query = "SELECT arg_max(intColumn, longColumn) FROM testTable";
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+
+ // intColumn peaks at 1999, where longColumn = 1999 - 1000 = 999; the winning record shows up twice
+ assertEquals(rows.get(0)[0], 999L);
+ assertEquals(rows.get(1)[0], 999L);
+ assertEquals(rows.size(), 2);
+
+ // Inter segment data type test
+ query = "SELECT arg_max(intColumn, longColumn), arg_max(intColumn, floatColumn), "
+ + "arg_max(intColumn, doubleColumn), arg_min(intColumn, mvIntColumn), "
+ + "arg_min(intColumn, mvStringColumn), arg_min(intColumn, intColumn), "
+ + "arg_max(bigDecimalColumn, bigDecimalColumn), arg_max(bigDecimalColumn, doubleColumn),"
+ + "arg_min(timestampColumn, timestampColumn), arg_max(bigDecimalColumn, mvDoubleColumn),"
+ + "arg_max(bigDecimalColumn, jsonColumn)"
+ + " FROM testTable";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ // Result columns are named argmax/argmin (no underscore) with the arguments joined without spaces
+ assertEquals(resultTable.getDataSchema().getColumnName(0), "argmax(intColumn,longColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(1), "argmax(intColumn,floatColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(2), "argmax(intColumn,doubleColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(3), "argmin(intColumn,mvIntColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(4), "argmin(intColumn,mvStringColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(5), "argmin(intColumn,intColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(6), "argmax(bigDecimalColumn,bigDecimalColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(7), "argmax(bigDecimalColumn,doubleColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(8), "argmin(timestampColumn,timestampColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(9), "argmax(bigDecimalColumn,mvDoubleColumn)");
+ assertEquals(resultTable.getDataSchema().getColumnName(10), "argmax(bigDecimalColumn,jsonColumn)");
+
+ // Extremes per the setUp formulas: intColumn min/max at i = 0 / 1999, bigDecimalColumn max at i = 600,
+ // timestampColumn min at i = 1999
+ assertEquals(rows.size(), 2);
+ assertEquals(rows.get(0)[0], 999L);
+ assertEquals(rows.get(1)[0], 999L);
+ assertEquals(rows.get(0)[1], 999.5F);
+ assertEquals(rows.get(1)[1], 999.5F);
+ assertEquals(rows.get(0)[2], 1999D);
+ assertEquals(rows.get(1)[2], 1999D);
+ assertEquals(rows.get(0)[3], new Integer[]{0, 1, 2});
+ assertEquals(rows.get(1)[3], new Integer[]{0, 1, 2});
+ assertEquals(rows.get(0)[4], new String[]{"a0", "a01", "a02"});
+ assertEquals(rows.get(1)[4], new String[]{"a0", "a01", "a02"});
+ assertEquals(rows.get(0)[5], 0);
+ assertEquals(rows.get(1)[5], 0);
+ assertEquals(rows.get(0)[6], "360000");
+ assertEquals(rows.get(1)[6], "360000");
+ assertEquals(rows.get(0)[7], 600D);
+ assertEquals(rows.get(1)[7], 600D);
+ assertEquals(rows.get(0)[8], 1683138373879L - 1999L);
+ assertEquals(rows.get(1)[8], 1683138373879L - 1999L);
+ assertEquals(rows.get(0)[9], new Double[]{600D, 600D * 600D, 600D * 600D * 600D});
+ assertEquals(rows.get(1)[9], new Double[]{600D, 600D * 600D, 600D * 600D * 600D});
+ assertEquals(rows.get(0)[10], "{\"name\":\"John\",\"age\":600,\"car\":null}");
+ assertEquals(rows.get(1)[10], "{\"name\":\"John\",\"age\":600,\"car\":null}");
+
+ // Inter segment data type test for boolean column
+ query = "SELECT arg_max(booleanColumn, booleanColumn) FROM testTable";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ // Every odd i ties for the maximum, so one result row is returned per tying record
+ assertEquals(rows.size(), 2000);
+ for (int i = 0; i < 2000; i++) {
+ assertEquals(rows.get(i)[0], 1);
+ }
+
+ // Inter segment mix aggregation function with different result length
+ // Inter segment string column comparison test, with dedupe
+ query = "SELECT sum(intColumn), argmin(stringColumn, doubleColumn), argmin(stringColumn, stringColumn), "
+ + "argmin(stringColumn, doubleColumn, doubleColumn) FROM testTable";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 4);
+
+ // "a11" is the smallest stringColumn value and occurs at i = 8 and i = 18. The regular aggregation is
+ // repeated on every row, while the column with the extra measuring argument has fewer values and is
+ // padded with null.
+ assertEquals(rows.get(0)[0], 7996000D);
+ assertEquals(rows.get(0)[1], 8D);
+ assertEquals(rows.get(0)[2], "a11");
+ assertEquals(rows.get(0)[3], 8D);
+
+ assertEquals(rows.get(1)[0], 7996000D);
+ assertEquals(rows.get(1)[1], 18D);
+ assertEquals(rows.get(1)[2], "a11");
+ assertEquals(rows.get(1)[3], 8D);
+
+ assertEquals(rows.get(2)[0], 7996000D);
+ assertEquals(rows.get(2)[1], 8D);
+ assertEquals(rows.get(2)[2], "a11");
+ assertNull(rows.get(2)[3]);
+
+ assertEquals(rows.get(3)[0], 7996000D);
+ assertEquals(rows.get(3)[1], 18D);
+ assertEquals(rows.get(3)[2], "a11");
+ assertNull(rows.get(3)[3]);
+
+ // Test transformation function inside argmax/argmin, for both projection and measuring
+ // 3000x - x^2 reaches its maximum, 2250000, at x = 1500
+ query = "SELECT sum(intColumn), argmax(3000 * doubleColumn - intColumn * intColumn, doubleColumn),"
+ + "argmax(3000 * doubleColumn - intColumn * intColumn, 3000 * doubleColumn - intColumn * intColumn),"
+ + "argmax(3000 * doubleColumn - intColumn * intColumn, doubleColumn), "
+ + "argmin(replace(stringColumn, \'a\', \'bb\'), replace(stringColumn, \'a\', \'bb\'))"
+ + "FROM testTable";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 4);
+
+ // NOTE(review): index 3 holds the argmin(replace(...)) value, so the repeated argmax expression
+ // apparently does not get a result column of its own - confirm
+ assertEquals(rows.get(0)[0], 7996000D);
+ assertEquals(rows.get(0)[1], 1500D);
+ assertEquals(rows.get(0)[2], 2250000D);
+ assertEquals(rows.get(0)[3], "bb11");
+ assertEquals(rows.get(1)[0], 7996000D);
+ assertEquals(rows.get(1)[1], 1500D);
+ assertEquals(rows.get(1)[2], 2250000D);
+ assertEquals(rows.get(1)[3], "bb11");
+ assertEquals(rows.get(2)[0], 7996000D);
+ assertNull(rows.get(2)[1]);
+ assertEquals(rows.get(2)[3], "bb11");
+ assertEquals(rows.get(3)[0], 7996000D);
+ assertNull(rows.get(3)[1]);
+ assertEquals(rows.get(3)[3], "bb11");
+
+ // Inter segment mix aggregation function with CASE statement
+ query = "SELECT argmin(CASE WHEN stringColumn = 'a33' THEN 'b' WHEN stringColumn = 'a22' THEN 'a' ELSE 'c' END"
+ + ", stringColumn), argmin(CASE WHEN stringColumn = 'a33' THEN 'b' WHEN stringColumn = 'a22' THEN 'a' "
+ + "ELSE 'c' END, CASE WHEN stringColumn = 'a33' THEN 'b' WHEN stringColumn = 'a22' THEN 'a' ELSE 'c' END) "
+ + "FROM testTable";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 4);
+
+ // Only the "a22" rows map to 'a', the smallest CASE result
+ for (int i = 0; i < 4; i++) {
+ assertEquals(rows.get(i)[0], "a22");
+ assertEquals(rows.get(i)[1], "a");
+ }
+
+ // TODO: The following query throws an exception,
+ // requires fix for multi-value bytes column serialization in DataBlock
+ query = "SELECT arg_min(intColumn, mvBytesColumn) FROM testTable";
+
+ try {
+ brokerResponse = getBrokerResponse(query);
+ fail("remove this test case, now mvBytesColumn works correctly in serialization");
+ } catch (Exception e) {
+ assertTrue(e.getMessage()
+ .contains("java.lang.IllegalArgumentException: Unsupported type of value: byte[][]"));
+ }
+ }
+
+ /**
+ * Checks how ties on the measuring columns are handled. Under the filter {@code doubleColumn <= 1200}, the
+ * records i = 0 and i = 1200 both have booleanColumn = 0 and bigDecimalColumn = 0, so they tie on the first
+ * two measuring columns; adding a third measuring column breaks the tie one way or the other.
+ */
+ @Test
+ public void testAggregationDedupe() {
+ // Inter segment dedupe test1 without dedupe
+ String query = "SELECT "
+ + "argmin(booleanColumn, bigDecimalColumn, intColumn) FROM testTable WHERE doubleColumn <= 1200";
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+
+ // Both tying records are returned, each twice
+ assertEquals(rows.size(), 4);
+
+ assertEquals(rows.get(0)[0], 0);
+ assertEquals(rows.get(1)[0], 1200);
+ assertEquals(rows.get(2)[0], 0);
+ assertEquals(rows.get(3)[0], 1200);
+
+ // test1, with dedupe
+ // doubleColumn as the third measuring column: the smaller value (i = 0) wins the tie
+ query = "SELECT "
+ + "argmin(booleanColumn, bigDecimalColumn, doubleColumn, intColumn) FROM testTable WHERE doubleColumn <= 1200";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 2);
+
+ assertEquals(rows.get(0)[0], 0);
+ assertEquals(rows.get(1)[0], 0);
+
+ // test2, with dedupe
+ // Negating the third measuring column flips the winner to i = 1200
+ query = "SELECT "
+ + "argmin(booleanColumn, bigDecimalColumn, 0-doubleColumn, intColumn) FROM testTable WHERE doubleColumn <= "
+ + "1200";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 2);
+
+ assertEquals(rows.get(0)[0], 1200);
+ assertEquals(rows.get(1)[0], 1200);
+ }
+
+ /**
+  * When the filter matches no document, the result must still contain exactly one row, with a null value for
+  * every arg_min/arg_max column, the usual column names, and STRING as the column type.
+  */
+ @Test
+ public void testEmptyAggregation() {
+   // intColumn never exceeds 1999, so nothing survives this filter
+   String emptyMatchQuery =
+       "SELECT arg_max(intColumn, longColumn), argmin(CASE WHEN stringColumn = 'a33' THEN 'b' "
+           + "WHEN stringColumn = 'a22' THEN 'a' ELSE 'c' END"
+           + ", stringColumn) FROM testTable where intColumn > 10000";
+
+   BrokerResponseNative response = getBrokerResponse(emptyMatchQuery);
+   ResultTable table = response.getResultTable();
+   List<Object[]> resultRows = table.getRows();
+
+   assertEquals(resultRows.size(), 1);
+   Object[] onlyRow = resultRows.get(0);
+   assertNull(onlyRow[0]);
+   assertNull(onlyRow[1]);
+
+   DataSchema schema = table.getDataSchema();
+   assertEquals(schema.getColumnName(0), "argmax(intColumn,longColumn)");
+   assertEquals(schema.getColumnName(1),
+       "argmin(case(equals(stringColumn,'a33'),equals(stringColumn,'a22'),'b','a','c'),stringColumn)");
+   assertEquals(schema.getColumnDataType(0), DataSchema.ColumnDataType.STRING);
+   assertEquals(schema.getColumnDataType(1), DataSchema.ColumnDataType.STRING);
+ }
+
+ /**
+ * Exercises arg_min/arg_max together with GROUP BY on a single-value column, with ORDER BY/LIMIT, on a
+ * multi-value column, and with multi-value projections. In every case each group contributes two result rows.
+ */
+ @Test
+ public void testGroupByInterSegment() {
+ // Simple inter segment group by
+ String query = "SELECT groupByIntColumn, arg_max(intColumn, longColumn) FROM testTable GROUP BY groupByIntColumn";
+
+ BrokerResponseNative brokerResponse = getBrokerResponse(query);
+ ResultTable resultTable = brokerResponse.getResultTable();
+ List<Object[]> rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 10);
+
+ // Groups arrive in the order 1, 1, 2, 2, 3, 3, 4, 4, 0, 0; the largest intColumn in group g is 1995 + g,
+ // whose longColumn is 995 + g
+ for (int i = 0; i < 10; i++) {
+ int group = ((i + 2) / 2) % 5;
+ assertEquals(rows.get(i)[0], group);
+ assertEquals(rows.get(i)[1], 995L + group);
+ }
+
+ // Simple inter segment group by with limit
+ query =
+ "SELECT groupByIntColumn2, arg_max(longColumn, doubleColumn) FROM testTable GROUP BY groupByIntColumn2 ORDER "
+ + "BY groupByIntColumn2 LIMIT 15";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ // All 12 groups (1, 2, 4, ..., 2048) fit under the limit of 15, and each yields two rows
+ assertEquals(rows.size(), 24);
+
+ // A group keyed by power of two p covers records up to i = p - 1
+ for (int i = 0; i < 22; i++) {
+ double group = Math.pow(2, i / 2);
+ assertEquals(rows.get(i)[0], (int) group);
+ assertEquals(rows.get(i)[1], group - 1);
+ }
+
+ // The last group is cut short by NUM_RECORDS, so its maximum is the final record
+ assertEquals(rows.get(22)[0], 2048);
+ assertEquals(rows.get(22)[1], 1999D);
+
+ assertEquals(rows.get(23)[0], 2048);
+ assertEquals(rows.get(23)[1], 1999D);
+
+ // MV inter segment group by
+ query = "SELECT groupByMVIntColumn, arg_min(intColumn, doubleColumn) FROM testTable GROUP BY groupByMVIntColumn";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+
+ assertEquals(rows.size(), 20);
+
+ // Record i belongs to groups i % 10 and (i + 1) % 10, so the first record in group g >= 1 is i = g - 1;
+ // groups 1..9 come first, group 0 last
+ for (int i = 0; i < 18; i++) {
+ int group = i / 2 + 1;
+ assertEquals(rows.get(i)[0], group);
+ assertEquals(rows.get(i)[1], (double) group - 1);
+ }
+
+ assertEquals(rows.get(18)[0], 0);
+ assertEquals(rows.get(18)[1], 0D);
+
+ assertEquals(rows.get(19)[0], 0);
+ assertEquals(rows.get(19)[1], 0D);
+
+ // MV inter segment group by with projection on MV column
+ query = "SELECT groupByMVIntColumn, arg_min(intColumn, mvIntColumn), "
+ + "arg_max(intColumn, mvStringColumn) FROM testTable GROUP BY groupByMVIntColumn";
+
+ brokerResponse = getBrokerResponse(query);
+ resultTable = brokerResponse.getResultTable();
+ rows = resultTable.getRows();
+ assertEquals(rows.size(), 20);
+
+ // The last record in group g >= 1 is i = 1990 + g; the expected strings are built by concatenation,
+ // mirroring how setUp fills mvStringColumn
+ for (int i = 0; i < 18; i++) {
+ int group = i / 2 + 1;
+ assertEquals(rows.get(i)[0], group);
+ assertEquals(rows.get(i)[1], new Object[]{group - 1, group, group + 1});
+ assertEquals(rows.get(i)[2], new Object[]{"a199" + group, "a199" + group + 1, "a199" + group + 2});
+ }
+
+ // Group 0: first record is i = 0, last record is i = 1999
+ assertEquals(rows.get(18)[0], 0);
+ assertEquals(rows.get(18)[1], new Object[]{0, 1, 2});
+ assertEquals(rows.get(18)[2], new Object[]{"a1999", "a19991", "a19992"});
+ }
+
+ /**
+  * Inter-segment group-by where the projected column is a VALUE_IN transform over an MV column, run once with a
+  * filter on the MV column and once without any filter.
+  */
+ @Test
+ public void testGroupByInterSegmentWithValueIn() {
+   String valueIn = "VALUE_IN(mvIntColumn,16,17,18,19,20,21,22,23,24,25,26,27)";
+   String selectClause =
+       "SELECT stringColumn, arg_min(intColumn, " + valueIn + "), arg_max(intColumn, " + valueIn + ") ";
+
+   // MV VALUE_IN segment group by, restricted to documents whose MV column holds one of the listed values
+   List<Object[]> filteredRows = getBrokerResponse(selectClause
+       + "FROM testTable WHERE mvIntColumn in (16,17,18,19,20,21,22,23,24,25,26,27) GROUP BY stringColumn")
+       .getResultTable().getRows();
+   assertEquals(filteredRows.size(), 14);
+   Object[] filteredA33 = filteredRows.get(4);
+   assertEquals(filteredA33[0], "a33");
+   assertEquals(filteredA33[1], new Object[]{20, 21, 22});
+   assertEquals(filteredA33[2], new Object[]{27});
+
+   // TODO: The unfiltered query works because whenever we find an empty array in the result, we use null
+   //  (see ArgMinMaxProjectionValSetWrapper). Ideally, we should be able to serialize empty array; that
+   //  requires a fix for empty int arrays ser/de in DataBlock.
+   List<Object[]> unfilteredRows = getBrokerResponse(selectClause + "FROM testTable GROUP BY stringColumn")
+       .getResultTable().getRows();
+   assertEquals(unfilteredRows.size(), 20);
+   Object[] unfilteredA33 = unfilteredRows.get(8);
+   assertEquals(unfilteredA33[0], "a33");
+   assertEquals(unfilteredA33[1], new Object[]{20, 21, 22});
+   assertEquals(unfilteredA33[2], new Object[]{});
+ }
+
+ /**
+  * Checks that the EXPLAIN PLAN output for a query with two arg_min calls mentions the expected child and parent
+  * argMin function expressions.
+  */
+ @Test
+ public void explainPlanTest() {
+   BrokerResponseNative response = getBrokerResponse(
+       "EXPLAIN PLAN FOR SELECT groupByMVIntColumn, arg_min(intColumn, mvIntColumn), "
+           + "arg_min(intColumn, doubleColumn, mvStringColumn) FROM testTable GROUP BY groupByMVIntColumn");
+
+   // The fourth plan row (index 3) is the one inspected for the aggregation function expressions
+   String planLine = response.getResultTable().getRows().get(3)[0].toString();
+
+   String[] expectedFragments = {
+       "child_argMin('0', mvIntColumn, intColumn, mvIntColumn)",
+       "child_argMin('1', mvStringColumn, intColumn, doubleColumn, mvStringColumn)",
+       "parent_argMin('0', '1', intColumn, mvIntColumn)",
+       "parent_argMin('1', '2', intColumn, doubleColumn, mvStringColumn)"
+   };
+   for (String fragment : expectedFragments) {
+     Assert.assertTrue(planLine.contains(fragment));
+   }
+ }
+
+ /**
+  * Group-by queries whose filter matches no documents must still report the expected column names and column
+  * types, with zero result rows.
+  */
+ @Test
+ public void testEmptyGroupByInterSegment() {
+   // Simple inter segment group by with no documents after filtering
+   String query = "SELECT groupByIntColumn, arg_max(intColumn, longColumn) FROM testTable "
+       + " where intColumn > 10000 GROUP BY groupByIntColumn";
+
+   BrokerResponseNative brokerResponse = getBrokerResponse(query);
+   ResultTable resultTable = brokerResponse.getResultTable();
+   DataSchema dataSchema = resultTable.getDataSchema();
+   List<Object[]> rows = resultTable.getRows();
+
+   assertEquals(dataSchema.getColumnName(0), "groupByIntColumn");
+   assertEquals(dataSchema.getColumnName(1), "argmax(intColumn,longColumn)");
+   assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.INT);
+   assertEquals(dataSchema.getColumnDataType(1), DataSchema.ColumnDataType.STRING);
+   assertEquals(rows.size(), 0);
+
+   // Same empty result, but mixing arg_max / arg_min with a regular aggregation (sum)
+   query = "SELECT groupByIntColumn, arg_max(intColumn, longColumn), sum(longColumn), arg_min(intColumn, longColumn)"
+       + " FROM testTable "
+       + " where intColumn > 10000 GROUP BY groupByIntColumn";
+
+   brokerResponse = getBrokerResponse(query);
+   resultTable = brokerResponse.getResultTable();
+   dataSchema = resultTable.getDataSchema();
+   rows = resultTable.getRows();
+   assertEquals(dataSchema.getColumnName(0), "groupByIntColumn");
+   assertEquals(dataSchema.getColumnName(1), "argmax(intColumn,longColumn)");
+   assertEquals(dataSchema.getColumnName(2), "sum(longColumn)");
+   assertEquals(dataSchema.getColumnName(3), "argmin(intColumn,longColumn)");
+   // Every column's type is checked exactly once, including the sum and the trailing arg_min column
+   assertEquals(dataSchema.getColumnDataType(0), DataSchema.ColumnDataType.INT);
+   assertEquals(dataSchema.getColumnDataType(1), DataSchema.ColumnDataType.STRING);
+   assertEquals(dataSchema.getColumnDataType(2), DataSchema.ColumnDataType.DOUBLE);
+   assertEquals(dataSchema.getColumnDataType(3), DataSchema.ColumnDataType.STRING);
+   assertEquals(rows.size(), 0);
+ }
+
+ /**
+  * Using argmin/argmax with an alias must fail, since the alias will not be resolved by the rewriter.
+  */
+ @Test
+ public void testAlias() {
+   String query = "SELECT groupByIntColumn, arg_max(intColumn, longColumn) AS"
+       + " argmax1 FROM testTable GROUP BY groupByIntColumn";
+   try {
+     getBrokerResponse(query);
+     // fail() raises an AssertionError, which the catch (Exception) below does not intercept
+     fail("Expected arg_max with an alias to be rejected");
+   } catch (Exception e) {
+     String message = e.getMessage();
+     // Guard against a null message so a wrong exception surfaces as an assertion failure, not an NPE
+     assertTrue(message != null && message.contains("Aggregation function: argmax(intColumn,longColumn) "
+         + "is only supported in selection without alias."), "Unexpected error message: " + message);
+   }
+ }
+
+ /**
+  * Using argmin/argmax in ORDER BY must fail, since the ordering on a multi-row projection is not well-defined.
+  */
+ @Test
+ public void testOrderBy() {
+   String query = "SELECT groupByIntColumn, arg_max(intColumn, longColumn) FROM testTable "
+       + "GROUP BY groupByIntColumn ORDER BY arg_max(intColumn, longColumn)";
+   try {
+     getBrokerResponse(query);
+     // fail() raises an AssertionError, which the catch (Exception) below does not intercept
+     fail("Expected arg_max in ORDER BY to be rejected");
+   } catch (Exception e) {
+     String message = e.getMessage();
+     // Guard against a null message so a wrong exception surfaces as an assertion failure, not an NPE
+     assertTrue(message != null && message.contains("Aggregation function: argmax(intColumn,longColumn) "
+         + "is only supported in selection without alias."), "Unexpected error message: " + message);
+   }
+ }
+
+ /**
+ * Class-level cleanup: destroys the index segment first, then deletes the index directory.
+ *
+ * @throws IOException declared for the directory deletion
+ */
+ @AfterClass
+ public void tearDown()
+ throws IOException {
+ _indexSegment.destroy();
+ FileUtils.deleteDirectory(INDEX_DIR);
+ }
+}
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/ResultRewriterRegressionTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/ResultRewriterRegressionTest.java
new file mode 100644
index 0000000000..0e4cd336a5
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/ResultRewriterRegressionTest.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pinot.queries;
+
+import org.apache.pinot.core.query.utils.rewriter.ResultRewriterFactory;
+import org.apache.pinot.sql.parsers.rewriter.QueryRewriterFactory;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+
+/**
+ * Regression test for queries with result rewriter.
+ */
+public class ResultRewriterRegressionTest {
+
+ @Test
+ public static class StatsQueriesRegressionTest extends StatisticalQueriesTest {
+ @BeforeClass
+ public void setupRewriter()
+ throws Exception {
+ QueryRewriterFactory.init(String.join(",", QueryRewriterFactory.DEFAULT_QUERY_REWRITERS_CLASS_NAMES)
+ + ",org.apache.pinot.sql.parsers.rewriter.ArgMinMaxRewriter");
+ ResultRewriterFactory
+ .init("org.apache.pinot.core.query.utils.rewriter.ParentAggregationResultRewriter");
+ }
+ }
+
+ @Test
+ public static class HistogramQueriesRegressionTest extends HistogramQueriesTest {
+ @BeforeClass
+ public void setupRewriter()
+ throws Exception {
+ QueryRewriterFactory.init(String.join(",", QueryRewriterFactory.DEFAULT_QUERY_REWRITERS_CLASS_NAMES)
+ + ",org.apache.pinot.sql.parsers.rewriter.ArgMinMaxRewriter");
+ ResultRewriterFactory
+ .init("org.apache.pinot.core.query.utils.rewriter.ParentAggregationResultRewriter");
+ }
+ }
+
+ @Test
+ public static class InterSegmentAggregationMultiValueQueriesRegressionTest
+ extends InterSegmentAggregationMultiValueQueriesTest {
+ @BeforeClass
+ public void setupRewriter()
+ throws Exception {
+ QueryRewriterFactory.init(String.join(",", QueryRewriterFactory.DEFAULT_QUERY_REWRITERS_CLASS_NAMES)
+ + ",org.apache.pinot.sql.parsers.rewriter.ArgMinMaxRewriter");
+ ResultRewriterFactory
+ .init("org.apache.pinot.core.query.utils.rewriter.ParentAggregationResultRewriter");
+ }
+ }
+}
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 7c355d0057..0a2a22dd00 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -23,6 +23,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;
+import org.apache.pinot.spi.utils.CommonConstants;
/**
@@ -94,7 +95,15 @@ public enum AggregationFunctionType {
// boolean aggregate functions
BOOLAND("boolAnd"),
- BOOLOR("boolOr");
+ BOOLOR("boolOr"),
+
+ // argMin and argMax
+ ARGMIN("argMin"),
+ ARGMAX("argMax"),
+ PARENTARGMIN(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARGMIN.getName()),
+ PARENTARGMAX(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + ARGMAX.getName()),
+ CHILDARGMIN(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + ARGMIN.getName()),
+ CHILDARGMAX(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + ARGMAX.getName());
private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
func.getName(), func.getName().toLowerCase())).collect(Collectors.toSet());
diff --git a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
index 123a789bcb..5e03e40217 100644
--- a/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
+++ b/pinot-spi/src/main/java/org/apache/pinot/spi/utils/CommonConstants.java
@@ -303,6 +303,8 @@ public class CommonConstants {
"pinot.broker.instance.enableThreadAllocatedBytesMeasurement";
public static final boolean DEFAULT_ENABLE_THREAD_CPU_TIME_MEASUREMENT = false;
public static final boolean DEFAULT_THREAD_ALLOCATED_BYTES_MEASUREMENT = false;
+ public static final String CONFIG_OF_BROKER_RESULT_REWRITER_CLASS_NAMES
+ = "pinot.broker.result.rewriter.class.names";
public static class Request {
public static final String SQL = "sql";
@@ -972,4 +974,11 @@ public class CommonConstants {
public static class IdealState {
public static final String HYBRID_TABLE_TIME_BOUNDARY = "HYBRID_TABLE_TIME_BOUNDARY";
}
+
+ /**
+ * String constants shared by the parent/child aggregation rewriting.
+ * The two prefixes are prepended to the argMin/argMax names in AggregationFunctionType to form the
+ * parent and child aggregation function names.
+ * NOTE(review): the two separator constants are not used in the visible part of this change; presumably they
+ * delimit parts of child aggregation names/keys - confirm against the rewriter classes.
+ */
+ public static class RewriterConstants {
+ public static final String PARENT_AGGREGATION_NAME_PREFIX = "parent";
+ public static final String CHILD_AGGREGATION_NAME_PREFIX = "child";
+ public static final String CHILD_AGGREGATION_SEPERATOR = "@";
+ public static final String CHILD_KEY_SEPERATOR = "_";
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org