Posted to commits@pinot.apache.org by ja...@apache.org on 2020/05/12 21:02:49 UTC
[incubator-pinot] branch master updated: Clean up
AggregationFunctionContext and use TransformExpressionTree as the key in
the blockValSetMap passed to the AggregationFunctions (#5364)
This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 8b0089f Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)
8b0089f is described below
commit 8b0089f4e8f8d323abbdfb0d8dfd0c79e49b41c2
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Tue May 12 14:02:38 2020 -0700
Clean up AggregationFunctionContext and use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions (#5364)
- Clean up all usages of AggregationFunctionContext to use AggregationFunction directly
- Construct the AggregationFunctions and Group-by Expressions at the planning phase and pass them to the Operator and Executor, avoiding redundant expression compilation
- Use TransformExpressionTree as the key in the blockValSetMap passed to the AggregationFunctions
- This avoids redundant string conversion and gives more efficient hashCode() and equals()
- The keys of the blockValSetMap should be the same as AggregationFunction.getInputExpressions()
- The only exception is CountAggregationFunction with Star-Tree where there is a single entry in blockValSetMap (column "*")
- Add a base implementation of AggregationFunction, BaseSingleInputAggregationFunction, for aggregation functions on a single input expression
- For the PERCENTILE family of aggregation functions, support passing the percentile as the second argument (e.g. PERCENTILE(column, 99), PERCENTILETDIGEST(column, 90))
- Enhance Star-Tree Aggregation/Group-by Executor to handle the column name conversion so that AggregationFunctionColumnPair is transparent to the AggregationFunction
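To illustrate why a structured key can be cheaper than a String key for blockValSetMap, here is a minimal sketch under stated assumptions: `ExprKey` is a hypothetical, immutable stand-in for an expression-tree node. It is not Pinot's TransformExpressionTree, whose hashCode() and equals() are not shown in this diff. The hash is computed once at construction, so a lookup costs an int comparison plus a structural equals() instead of rebuilding the expression string.

```java
import java.util.List;
import java.util.Objects;

/**
 * Hypothetical immutable expression-tree key (not Pinot's TransformExpressionTree).
 * The hash is computed once, so repeated HashMap lookups never rebuild a string.
 */
final class ExprKey {
  enum Type { FUNCTION, IDENTIFIER, LITERAL }

  private final Type _type;
  private final String _value;
  private final List<ExprKey> _children;
  private final int _hashCode;

  ExprKey(Type type, String value, List<ExprKey> children) {
    _type = Objects.requireNonNull(type);
    _value = Objects.requireNonNull(value);
    _children = children == null ? List.of() : List.copyOf(children);
    // Children return their cached hashes, so this stays cheap even for deep trees.
    _hashCode = Objects.hash(_type, _value, _children);
  }

  static ExprKey column(String name) {
    return new ExprKey(Type.IDENTIFIER, name, null);
  }

  static ExprKey literal(String value) {
    return new ExprKey(Type.LITERAL, value, null);
  }

  static ExprKey function(String name, ExprKey... args) {
    return new ExprKey(Type.FUNCTION, name, List.of(args));
  }

  @Override
  public int hashCode() {
    return _hashCode;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey other = (ExprKey) o;
    // The cached-hash check rejects most non-matches before the structural comparison.
    return _hashCode == other._hashCode && _type == other._type && _value.equals(other._value)
        && _children.equals(other._children);
  }
}
```

A key built for one operator and an equal key built independently elsewhere land on the same map entry, while a column and a literal with the same text stay distinct because the node type is part of the key.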
BACKWARD-INCOMPATIBLE CHANGE:
The following AggregationFunction APIs have changed (TransformExpressionTree instead of String is now the key of blockValSetMap):
void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<TransformExpressionTree, BlockValSet> blockValSetMap);
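Under the new signatures, an aggregation function looks up its input by the same expression object it reports from getInputExpressions(). The sketch below is hypothetical: `DoubleValSet` stands in for Pinot's BlockValSet, and the generic key type K stands in for TransformExpressionTree. It only shows the lookup contract, not Pinot's actual SumAggregationFunction.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for Pinot's BlockValSet, exposing only double values. */
interface DoubleValSet {
  double[] getDoubleValues();
}

/**
 * Hypothetical SUM sketch following the new contract: blockValSetMap is keyed by the same
 * expression object that getInputExpressions() returns (K would be TransformExpressionTree).
 */
final class SumSketch<K> {
  private final K _inputExpression;
  private double _sum;

  SumSketch(K inputExpression) {
    _inputExpression = inputExpression;
  }

  List<K> getInputExpressions() {
    return List.of(_inputExpression);
  }

  void aggregate(int length, Map<K, DoubleValSet> blockValSetMap) {
    DoubleValSet valSet = blockValSetMap.get(_inputExpression);
    if (valSet == null) {
      throw new IllegalStateException("No BlockValSet for input expression: " + _inputExpression);
    }
    double[] values = valSet.getDoubleValues();
    // Only the first `length` entries belong to the current block.
    for (int i = 0; i < length; i++) {
      _sum += values[i];
    }
  }

  double getResult() {
    return _sum;
  }
}
```

Failing loudly on a missing key reflects the rule above that the map keys should match getInputExpressions(); the Star-Tree COUNT case ("*") would need its own handling.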
---
.../requesthandler/BaseBrokerRequestHandler.java | 15 +-
.../common/function/AggregationFunctionType.java | 12 +-
.../request/transform/TransformExpressionTree.java | 23 +-
.../parsers/PinotQuery2BrokerRequestConverter.java | 2 +-
.../core/common/datatable/DataTableUtils.java | 23 +-
.../apache/pinot/core/data/table/BaseTable.java | 31 +--
.../core/data/table/ConcurrentIndexedTable.java | 10 +-
.../apache/pinot/core/data/table/IndexedTable.java | 7 +-
.../pinot/core/data/table/SimpleIndexedTable.java | 10 +-
.../apache/pinot/core/data/table/TableResizer.java | 40 ++-
.../core/operator/CombineGroupByOperator.java | 13 +-
.../operator/CombineGroupByOrderByOperator.java | 10 +-
.../operator/blocks/IntermediateResultsBlock.java | 48 ++--
.../operator/query/AggregationGroupByOperator.java | 31 ++-
.../query/AggregationGroupByOrderByOperator.java | 64 +++--
.../core/operator/query/AggregationOperator.java | 14 +-
.../query/DictionaryBasedAggregationOperator.java | 49 ++--
.../query/MetadataBasedAggregationOperator.java | 32 +--
.../plan/AggregationGroupByOrderByPlanNode.java | 73 +++--
.../core/plan/AggregationGroupByPlanNode.java | 62 +++--
.../pinot/core/plan/AggregationPlanNode.java | 49 ++--
.../plan/DictionaryBasedAggregationPlanNode.java | 23 +-
.../plan/MetadataBasedAggregationPlanNode.java | 33 +--
.../core/plan/maker/InstancePlanMakerImplV2.java | 7 +-
.../aggregation/AggregationFunctionContext.java | 85 ------
.../aggregation/DefaultAggregationExecutor.java | 66 ++---
.../core/query/aggregation/DistinctTable.java | 8 +-
.../aggregation/function/AggregationFunction.java | 7 +-
.../function/AggregationFunctionFactory.java | 102 ++++---
.../function/AggregationFunctionUtils.java | 191 ++++++-------
.../function/AvgAggregationFunction.java | 44 +--
.../function/AvgMVAggregationFunction.java | 18 +-
.../BaseSingleInputAggregationFunction.java | 57 ++++
.../function/CountAggregationFunction.java | 32 +--
.../function/CountMVAggregationFunction.java | 29 +-
.../function/DistinctAggregationFunction.java | 39 ++-
.../function/DistinctCountAggregationFunction.java | 42 +--
.../DistinctCountHLLAggregationFunction.java | 42 +--
.../DistinctCountHLLMVAggregationFunction.java | 18 +-
.../DistinctCountMVAggregationFunction.java | 18 +-
.../DistinctCountRawHLLAggregationFunction.java | 41 +--
.../DistinctCountRawHLLMVAggregationFunction.java | 6 +-
...istinctCountThetaSketchAggregationFunction.java | 53 ++--
.../function/FastHLLAggregationFunction.java | 42 +--
.../function/MaxAggregationFunction.java | 42 +--
.../function/MaxMVAggregationFunction.java | 18 +-
.../function/MinAggregationFunction.java | 42 +--
.../function/MinMVAggregationFunction.java | 18 +-
.../function/MinMaxRangeAggregationFunction.java | 41 +--
.../function/MinMaxRangeMVAggregationFunction.java | 18 +-
.../function/PercentileAggregationFunction.java | 48 +---
.../function/PercentileEstAggregationFunction.java | 43 +--
.../PercentileEstMVAggregationFunction.java | 28 +-
.../function/PercentileMVAggregationFunction.java | 28 +-
.../PercentileTDigestAggregationFunction.java | 43 +--
.../PercentileTDigestMVAggregationFunction.java | 28 +-
.../function/SumAggregationFunction.java | 41 +--
.../function/SumMVAggregationFunction.java | 18 +-
.../groupby/DefaultGroupByExecutor.java | 102 +++----
.../query/aggregation/groupby/GroupByExecutor.java | 3 +-
.../query/reduce/AggregationDataTableReducer.java | 47 ++--
.../pinot/core/query/reduce/CombineService.java | 24 +-
.../core/query/reduce/ComparisonFunction.java | 5 +-
.../query/reduce/DistinctDataTableReducer.java | 13 +-
.../core/query/reduce/GroupByDataTableReducer.java | 16 +-
.../core/query/request/ServerQueryRequest.java | 2 +-
.../apache/pinot/core/startree/StarTreeUtils.java | 32 +--
.../executor/StarTreeAggregationExecutor.java | 55 +---
.../startree/executor/StarTreeGroupByExecutor.java | 70 ++---
.../startree/plan/StarTreeTransformPlanNode.java | 10 +-
.../pinot/core/data/table/IndexedTableTest.java | 128 ++++-----
.../pinot/core/data/table/TableResizerTest.java | 305 +++++++++------------
.../function/AggregationFunctionFactoryTest.java | 9 +-
.../pinot/core/startree/v2/BaseStarTreeV2Test.java | 34 +--
...terSegmentAggregationMultiValueQueriesTest.java | 122 +++++----
...erSegmentAggregationSingleValueQueriesTest.java | 41 +--
...terSegmentResultTableMultiValueQueriesTest.java | 79 +++---
...erSegmentResultTableSingleValueQueriesTest.java | 78 +++---
.../DefaultAggregationExecutorTest.java | 33 +--
.../apache/pinot/perf/BenchmarkCombineGroupBy.java | 76 ++---
.../apache/pinot/perf/BenchmarkIndexedTable.java | 61 ++---
81 files changed, 1430 insertions(+), 1992 deletions(-)
diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
index c9f5e29..6a00878 100644
--- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
+++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java
@@ -456,13 +456,9 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
for (AggregationInfo info : brokerRequest.getAggregationsInfo()) {
if (!info.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
// Always read from backward compatible api in AggregationFunctionUtils.
- List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(info);
-
- List<String> newExpressions = new ArrayList<>(expressions.size());
- for (String expression : expressions) {
- newExpressions.add(fixColumnNameCase(actualTableName, expression));
- }
- info.setExpressions(newExpressions);
+ List<String> arguments = AggregationFunctionUtils.getArguments(info);
+ arguments.replaceAll(e -> fixColumnNameCase(actualTableName, e));
+ info.setExpressions(arguments);
}
}
if (brokerRequest.isSetGroupBy()) {
@@ -720,11 +716,10 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler {
throw new UnsupportedOperationException("DISTINCT with GROUP BY is currently not supported");
}
if (brokerRequest.isSetOrderBy()) {
- List<String> columns = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
- Set<String> set = new HashSet<>(columns);
+ Set<String> expressionSet = new HashSet<>(AggregationFunctionUtils.getArguments(aggregationInfo));
List<SelectionSort> orderByColumns = brokerRequest.getOrderBy();
for (SelectionSort selectionSort : orderByColumns) {
- if (!set.contains(selectionSort.getColumn())) {
+ if (!expressionSet.contains(selectionSort.getColumn())) {
throw new UnsupportedOperationException(
"ORDER By should be only on some/all of the columns passed as arguments to DISTINCT");
}
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
index d3cef28..af31639 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/AggregationFunctionType.java
@@ -76,17 +76,17 @@ public enum AggregationFunctionType {
String upperCaseFunctionName = functionName.toUpperCase();
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
- if (remainingFunctionName.matches("\\d+")) {
+ if (remainingFunctionName.isEmpty() || remainingFunctionName.matches("\\d+")) {
return PERCENTILE;
- } else if (remainingFunctionName.matches("EST\\d+")) {
+ } else if (remainingFunctionName.equals("EST") || remainingFunctionName.matches("EST\\d+")) {
return PERCENTILEEST;
- } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
+ } else if (remainingFunctionName.equals("TDIGEST") || remainingFunctionName.matches("TDIGEST\\d+")) {
return PERCENTILETDIGEST;
- } else if (remainingFunctionName.matches("\\d+MV")) {
+ } else if (remainingFunctionName.equals("MV") || remainingFunctionName.matches("\\d+MV")) {
return PERCENTILEMV;
- } else if (remainingFunctionName.matches("EST\\d+MV")) {
+ } else if (remainingFunctionName.equals("ESTMV") || remainingFunctionName.matches("EST\\d+MV")) {
return PERCENTILEESTMV;
- } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
+ } else if (remainingFunctionName.equals("TDIGESTMV") || remainingFunctionName.matches("TDIGEST\\d+MV")) {
return PERCENTILETDIGESTMV;
} else {
throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
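The AggregationFunctionType hunk makes the numeric suffix optional for every PERCENTILE variant, which is what lets the percentile come from the second argument instead (e.g. PERCENTILE(column, 99)). In the minimal sketch below, the six regex branches collapse into one pattern that accepts the same names. `PercentileNameSketch` is hypothetical, and the fallback to a second argument parsed as a double is an assumption based on the commit message, since that code is not part of this hunk.

```java
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical sketch of PERCENTILE* name resolution (not Pinot's AggregationFunctionType). */
final class PercentileNameSketch {
  enum Type { PERCENTILE, PERCENTILEEST, PERCENTILETDIGEST, PERCENTILEMV, PERCENTILEESTMV, PERCENTILETDIGESTMV }

  private static final String PREFIX = "PERCENTILE";
  // Remainder after the prefix: optional variant, optional percentile digits, optional MV.
  // Accepts the same names as the six two-way branches in the hunk above.
  private static final Pattern SUFFIX = Pattern.compile("(EST|TDIGEST)?(\\d*)(MV)?");

  private static Matcher parse(String functionName) {
    String upper = functionName.toUpperCase();
    if (upper.startsWith(PREFIX)) {
      Matcher matcher = SUFFIX.matcher(upper.substring(PREFIX.length()));
      if (matcher.matches()) {
        return matcher;
      }
    }
    throw new IllegalArgumentException("Invalid aggregation function name: " + functionName);
  }

  static Type resolveType(String functionName) {
    Matcher m = parse(functionName);
    String variant = m.group(1) == null ? "" : m.group(1);
    String mv = m.group(3) == null ? "" : "MV";
    return Type.valueOf(PREFIX + variant + mv);
  }

  /** Percentile from the name suffix (PERCENTILE99) or, if absent, from the second argument (assumed). */
  static double resolvePercentile(String functionName, List<String> arguments) {
    String digits = parse(functionName).group(2);
    if (!digits.isEmpty()) {
      return Integer.parseInt(digits);
    }
    if (arguments.size() < 2) {
      throw new IllegalArgumentException("Percentile missing for: " + functionName);
    }
    return Double.parseDouble(arguments.get(1));
  }
}
```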
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java b/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
index 5c0a515..527df87 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/request/transform/TransformExpressionTree.java
@@ -22,14 +22,13 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
-import javax.annotation.Nonnull;
-import org.apache.pinot.spi.utils.EqualityUtils;
+import javax.annotation.Nullable;
import org.apache.pinot.pql.parsers.Pql2Compiler;
import org.apache.pinot.pql.parsers.pql2.ast.AstNode;
import org.apache.pinot.pql.parsers.pql2.ast.FunctionCallAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
import org.apache.pinot.pql.parsers.pql2.ast.LiteralAstNode;
-import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
+import org.apache.pinot.spi.utils.EqualityUtils;
/**
@@ -37,7 +36,7 @@ import org.apache.pinot.pql.parsers.pql2.ast.StringLiteralAstNode;
* <ul>
* <li>A TransformExpressionTree node has either transform function or a column name, or a literal.</li>
* <li>Leaf nodes either have column name or literal, whereas non-leaf nodes have transform function.</li>
- * <li>Transform function in non-leaf nodes is applied to its children nodes.</li>
+ * <li>Transform function is applied to its children.</li>
* </ul>
*/
public class TransformExpressionTree {
@@ -66,10 +65,9 @@ public class TransformExpressionTree {
} else if (astNode instanceof FunctionCallAstNode) {
// UDF expression
return standardizeExpression(((FunctionCallAstNode) astNode).getExpression());
- } else if (astNode instanceof StringLiteralAstNode) {
- // Treat string as column name
- // NOTE: this is for backward-compatibility
- return ((StringLiteralAstNode) astNode).getText();
+ } else if (astNode instanceof LiteralAstNode) {
+ // Literal
+ return ((LiteralAstNode) astNode).getValueAsString();
} else {
throw new IllegalStateException("Cannot get standard expression from " + astNode.getClass().getSimpleName());
}
@@ -106,6 +104,13 @@ public class TransformExpressionTree {
}
}
+ public TransformExpressionTree(ExpressionType expressionType, String value,
+ @Nullable List<TransformExpressionTree> children) {
+ _expressionType = expressionType;
+ _value = value;
+ _children = children;
+ }
+
/**
* Returns the expression type of the node, which can be one of the following:
* <ul>
@@ -168,7 +173,7 @@ public class TransformExpressionTree {
*
* @param columns Output columns
*/
- public void getColumns(@Nonnull Set<String> columns) {
+ public void getColumns(Set<String> columns) {
if (_expressionType == ExpressionType.IDENTIFIER) {
columns.add(_value);
} else if (_children != null) {
diff --git a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
index 6ce2f16..1ce9a06 100644
--- a/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
+++ b/pinot-common/src/main/java/org/apache/pinot/pql/parsers/PinotQuery2BrokerRequestConverter.java
@@ -226,7 +226,7 @@ public class PinotQuery2BrokerRequestConverter {
private String getColumnExpression(Expression functionParam) {
switch (functionParam.getType()) {
case LITERAL:
- return functionParam.getLiteral().getStringValue();
+ return functionParam.getLiteral().getFieldValue().toString();
case IDENTIFIER:
return functionParam.getIdentifier().getName();
case FUNCTION:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
index 2122779..1574a24 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/common/datatable/DataTableUtils.java
@@ -26,7 +26,6 @@ import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.Selection;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataTable;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.util.QueryOptions;
@@ -101,9 +100,8 @@ public class DataTableUtils {
}
// Aggregation query.
- AggregationFunctionContext[] aggregationFunctionContexts =
- AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
- int numAggregations = aggregationFunctionContexts.length;
+ AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+ int numAggregations = aggregationFunctions.length;
if (brokerRequest.isSetGroupBy()) {
// Aggregation group-by query.
@@ -121,9 +119,9 @@ public class DataTableUtils {
columnDataTypes[index] = DataSchema.ColumnDataType.STRING;
index++;
}
- for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
- columnNames[index] = aggregationFunctionContext.getResultColumnName();
- AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
+ for (AggregationFunction aggregationFunction : aggregationFunctions) {
+ // NOTE: Use AggregationFunction.getResultColumnName() for SQL format response
+ columnNames[index] = aggregationFunction.getResultColumnName();
columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
index++;
}
@@ -137,9 +135,10 @@ public class DataTableUtils {
// Build the data table.
DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
- for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
+ for (AggregationFunction aggregationFunction : aggregationFunctions) {
dataTableBuilder.startRow();
- dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
+ // NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for PQL format response
+ dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
dataTableBuilder.setColumn(1, Collections.emptyMap());
dataTableBuilder.finishRow();
}
@@ -152,9 +151,9 @@ public class DataTableUtils {
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numAggregations];
Object[] aggregationResults = new Object[numAggregations];
for (int i = 0; i < numAggregations; i++) {
- AggregationFunctionContext aggregationFunctionContext = aggregationFunctionContexts[i];
- aggregationColumnNames[i] = aggregationFunctionContext.getAggregationColumnName();
- AggregationFunction aggregationFunction = aggregationFunctionContext.getAggregationFunction();
+ AggregationFunction aggregationFunction = aggregationFunctions[i];
+ // NOTE: For backward-compatibility, use AggregationFunction.getColumnName() for aggregation only query
+ aggregationColumnNames[i] = aggregationFunction.getColumnName();
columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
aggregationResults[i] =
aggregationFunction.extractAggregationResult(aggregationFunction.createAggregationResultHolder());
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
index 03d8b3b..3ad90ac 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/BaseTable.java
@@ -21,22 +21,20 @@ package org.apache.pinot.core.data.table;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.CollectionUtils;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
/**
* Base abstract implementation of Table
*/
public abstract class BaseTable implements Table {
-
- final AggregationFunction[] _aggregationFunctions;
- final int _numAggregations;
+ // TODO: After fixing the DistinctTable logic, make it final
protected DataSchema _dataSchema;
- final int _numColumns;
+ protected final int _numColumns;
+ protected final AggregationFunction[] _aggregationFunctions;
+ protected final int _numAggregations;
// the capacity we need to trim to
protected int _capacity;
@@ -46,34 +44,27 @@ public abstract class BaseTable implements Table {
protected boolean _isOrderBy;
protected TableResizer _tableResizer;
- private final List<AggregationInfo> _aggregationInfos;
-
/**
* Initializes the variables and comparators needed for the table
*/
- public BaseTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
+ public BaseTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
+ int capacity) {
_dataSchema = dataSchema;
_numColumns = dataSchema.size();
-
- _numAggregations = aggregationInfos.size();
- _aggregationFunctions = new AggregationFunction[_numAggregations];
- for (int i = 0; i < _numAggregations; i++) {
- _aggregationFunctions[i] =
- AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(i)).getAggregationFunction();
- }
-
- _aggregationInfos = aggregationInfos;
+ _aggregationFunctions = aggregationFunctions;
+ _numAggregations = aggregationFunctions.length;
addCapacityAndOrderByInfo(orderBy, capacity);
}
protected void addCapacityAndOrderByInfo(List<SelectionSort> orderBy, int capacity) {
_isOrderBy = CollectionUtils.isNotEmpty(orderBy);
if (_isOrderBy) {
- _tableResizer = new TableResizer(_dataSchema, _aggregationInfos, orderBy);
+ _tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, orderBy);
// TODO: tune these numbers and come up with a better formula (github ISSUE-4801)
// Based on the capacity and maxCapacity, the resizer will smartly choose to evict/retain recors from the PQ
- if (capacity <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
+ if (capacity
+ <= 100_000) { // Capacity is small, make a very large buffer. Make PQ of records to retain, during resize
_maxCapacity = 1_000_000;
} else { // Capacity is large, make buffer only slightly bigger. Make PQ of records to evict, during resize
_maxCapacity = (int) (capacity * 1.2);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
index aa52218..3c25ef2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/ConcurrentIndexedTable.java
@@ -27,9 +27,9 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,13 +52,13 @@ public class ConcurrentIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
- * @param aggregationInfos aggregation infos for the aggregations in record's values
+ * @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
- public ConcurrentIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
- int capacity) {
- super(dataSchema, aggregationInfos, orderBy, capacity);
+ public ConcurrentIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
+ List<SelectionSort> orderBy, int capacity) {
+ super(dataSchema, aggregationFunctions, orderBy, capacity);
_lookupMap = new ConcurrentHashMap<>();
_readWriteLock = new ReentrantReadWriteLock();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
index 0f29f7a..cb2dc33 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/IndexedTable.java
@@ -20,9 +20,9 @@ package org.apache.pinot.core.data.table;
import java.util.Arrays;
import java.util.List;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
/**
@@ -36,8 +36,9 @@ public abstract class IndexedTable extends BaseTable {
/**
* Initializes the variables and comparators needed for the table
*/
- IndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy, int capacity) {
- super(dataSchema, aggregationInfos, orderBy, capacity);
+ IndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy,
+ int capacity) {
+ super(dataSchema, aggregationFunctions, orderBy, capacity);
_numKeyColumns = dataSchema.size() - _numAggregations;
_keyExtractor = new KeyExtractor(_numKeyColumns);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
index cb28c08..ffdf103 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/SimpleIndexedTable.java
@@ -24,9 +24,9 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.annotation.concurrent.NotThreadSafe;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -48,13 +48,13 @@ public class SimpleIndexedTable extends IndexedTable {
/**
* Initializes the data structures needed for this Table
* @param dataSchema data schema of the record's keys and values
- * @param aggregationInfos aggregation infos for the aggregations in record'd values
+ * @param aggregationFunctions aggregation functions for the record's values
* @param orderBy list of {@link SelectionSort} defining the order by
* @param capacity the capacity of the table
*/
- public SimpleIndexedTable(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy,
- int capacity) {
- super(dataSchema, aggregationInfos, orderBy, capacity);
+ public SimpleIndexedTable(DataSchema dataSchema, AggregationFunction[] aggregationFunctions,
+ List<SelectionSort> orderBy, int capacity) {
+ super(dataSchema, aggregationFunctions, orderBy, capacity);
_lookupMap = new HashMap<>();
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
index 19b8b4c..1fa1deb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/data/table/TableResizer.java
@@ -31,11 +31,9 @@ import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Function;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
/**
@@ -48,13 +46,13 @@ public class TableResizer {
private Comparator<Record> _recordComparator;
protected int _numOrderBy;
- TableResizer(DataSchema dataSchema, List<AggregationInfo> aggregationInfos, List<SelectionSort> orderBy) {
+ TableResizer(DataSchema dataSchema, AggregationFunction[] aggregationFunctions, List<SelectionSort> orderBy) {
// NOTE: the assumption here is that the key columns will appear before the aggregation columns in the data schema
// This is handled in the only in the AggregationGroupByOrderByOperator for now
int numColumns = dataSchema.size();
- int numAggregations = aggregationInfos.size();
+ int numAggregations = aggregationFunctions.length;
int numKeyColumns = numColumns - numAggregations;
Map<String, Integer> columnIndexMap = new HashMap<>();
@@ -63,10 +61,7 @@ public class TableResizer {
String columnName = dataSchema.getColumnName(i);
columnIndexMap.put(columnName, i);
if (i >= numKeyColumns) {
- AggregationInfo aggregationInfo = aggregationInfos.get(i - numKeyColumns);
- AggregationFunction aggregationFunction =
- AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo).getAggregationFunction();
- aggregationColumnToFunction.put(columnName, aggregationFunction);
+ aggregationColumnToFunction.put(columnName, aggregationFunctions[i - numKeyColumns]);
}
}
@@ -109,7 +104,8 @@ public class TableResizer {
};
} else {
// For cases where the entire Record is unique and is treated as a key
- Preconditions.checkState(numKeyColumns == numColumns, "number of key columns should be equal to total number of columns");
+ Preconditions
+ .checkState(numKeyColumns == numColumns, "number of key columns should be equal to total number of columns");
int[] orderByIndexes = new int[_numOrderBy];
boolean[] orderByAsc = new boolean[_numOrderBy];
for (int i = 0; i < _numOrderBy; i++) {
@@ -194,7 +190,7 @@ public class TableResizer {
return priorityQueue;
}
- private List<Record> sortRecordsMap(Map<Key, Record> recordsMap) {
+ private List<Record> sortRecordsMap(Map<Key, Record> recordsMap) {
int numRecords = recordsMap.size();
List<Record> sortedRecords = new ArrayList<>(numRecords);
List<IntermediateRecord> intermediateRecords = new ArrayList<>(numRecords);
@@ -321,7 +317,6 @@ public class TableResizer {
}
}
-
/********************************************************
* *
* Resize functions for Set based table implementation *
@@ -342,9 +337,10 @@ public class TableResizer {
Object[] values1 = record1.getValues();
Object[] values2 = record2.getValues();
for (int i = 0; i < _numOrderBy; i++) {
- Comparable valueToCompare1 = (Comparable)values1[_orderByColumnIndexes[i]];
- Comparable valueToCompare2 = (Comparable)values2[_orderByColumnIndexes[i]];
- int result = _orderByAsc[i] ? valueToCompare1.compareTo(valueToCompare2) : valueToCompare2.compareTo(valueToCompare1);
+ Comparable valueToCompare1 = (Comparable) values1[_orderByColumnIndexes[i]];
+ Comparable valueToCompare2 = (Comparable) values2[_orderByColumnIndexes[i]];
+ int result =
+ _orderByAsc[i] ? valueToCompare1.compareTo(valueToCompare2) : valueToCompare2.compareTo(valueToCompare1);
if (result != 0) {
return result;
}
@@ -359,14 +355,16 @@ public class TableResizer {
if (numRecordsToEvict < trimToSize) {
// num records to evict is smaller than num records to retain
// make PQ of records to evict
- PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
+ PriorityQueue<Record> priorityQueue =
+ buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
for (Record recordToEvict : priorityQueue) {
recordSet.remove(recordToEvict);
}
} else {
// num records to retain is smaller than num records to evict
// make PQ of records to retain
- PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(trimToSize, recordSet, _recordComparator.reversed());
+ PriorityQueue<Record> priorityQueue =
+ buildPriorityQueueFromRecordSet(trimToSize, recordSet, _recordComparator.reversed());
ObjectOpenHashSet<Record> recordsToRetain = new ObjectOpenHashSet<>(priorityQueue.size());
for (Record recordToRetain : priorityQueue) {
recordsToRetain.add(recordToRetain);
@@ -376,9 +374,7 @@ public class TableResizer {
}
}
- private PriorityQueue<Record> buildPriorityQueueFromRecordSet(
- int size,
- Set<Record> recordSet,
+ private PriorityQueue<Record> buildPriorityQueueFromRecordSet(int size, Set<Record> recordSet,
Comparator<Record> comparator) {
PriorityQueue<Record> priorityQueue = new PriorityQueue<>(size, comparator);
for (Record record : recordSet) {
@@ -416,7 +412,8 @@ public class TableResizer {
// num records to evict is smaller than num records to retain
if (numRecordsToEvict > 0) {
// make PQ of records to evict
- PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
+ PriorityQueue<Record> priorityQueue =
+ buildPriorityQueueFromRecordSet(numRecordsToEvict, recordSet, _recordComparator);
for (Record recordToEvict : priorityQueue) {
recordSet.remove(recordToEvict);
}
@@ -424,7 +421,8 @@ public class TableResizer {
return sortRecordSet(recordSet);
} else {
// make PQ of records to retain
- PriorityQueue<Record> priorityQueue = buildPriorityQueueFromRecordSet(numRecordsToRetain, recordSet, _recordComparator.reversed());
+ PriorityQueue<Record> priorityQueue =
+ buildPriorityQueueFromRecordSet(numRecordsToRetain, recordSet, _recordComparator.reversed());
// use PQ to get sorted list
Record[] sortedArray = new Record[numRecordsToRetain];
while (!priorityQueue.isEmpty()) {
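The TableResizer hunks above always build the smaller of two heaps. When fewer records are evicted than kept, they build a priority queue of the records to evict using `_recordComparator`. Otherwise they build one of the records to retain using `_recordComparator.reversed()`. The body of `buildPriorityQueueFromRecordSet` is cut off in this excerpt, so what follows is only a minimal sketch of the idea under an assumed bounded-heap implementation. `TopKTrimSketch`, `trim`, and `buildBoundedQueue` are hypothetical names, and plain `Integer`s stand in for `Record`s.

```java
import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeSet;

public class TopKTrimSketch {

  // Hypothetical bounded heap: whenever it grows past `size`, the head
  // (the smallest element according to `comparator`) is discarded.
  static <T> PriorityQueue<T> buildBoundedQueue(int size, Set<T> set, Comparator<T> comparator) {
    PriorityQueue<T> queue = new PriorityQueue<>(size + 1, comparator);
    for (T item : set) {
      queue.offer(item);
      if (queue.size() > size) {
        queue.poll();
      }
    }
    return queue;
  }

  // Trims `set` in place to the first `trimToSize` elements in `order`,
  // building a heap only as large as min(numToEvict, trimToSize).
  static <T> void trim(Set<T> set, int trimToSize, Comparator<T> order) {
    int numToEvict = set.size() - trimToSize;
    if (numToEvict <= 0) {
      return;
    }
    if (numToEvict < trimToSize) {
      // Heap ordered by `order`: polling drops the best, leaving the worst `numToEvict`
      PriorityQueue<T> toEvict = buildBoundedQueue(numToEvict, set, order);
      set.removeAll(toEvict);
    } else {
      // Heap ordered by reversed `order`: polling drops the worst, leaving the best `trimToSize`
      PriorityQueue<T> toRetain = buildBoundedQueue(trimToSize, set, order.reversed());
      set.retainAll(new HashSet<>(toRetain));
    }
  }

  public static void main(String[] args) {
    Set<Integer> set = new HashSet<>(Set.of(5, 1, 9, 3, 7, 2));
    trim(set, 4, Comparator.naturalOrder()); // keep the 4 smallest
    System.out.println(new TreeSet<>(set)); // [1, 2, 3, 5]
  }
}
```

Either way the heap holds at most min(numToEvict, trimToSize) + 1 entries at any moment, which is why the code branches on which side is smaller.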
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
index 0508142..a7bb5c2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOperator.java
@@ -37,7 +37,6 @@ import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.response.ProcessingException;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
@@ -108,13 +107,8 @@ public class CombineGroupByOperator extends BaseOperator<IntermediateResultsBloc
AtomicInteger numGroups = new AtomicInteger();
ConcurrentLinkedQueue<ProcessingException> mergedProcessingExceptions = new ConcurrentLinkedQueue<>();
- AggregationFunctionContext[] aggregationFunctionContexts =
- AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
- int numAggregationFunctions = aggregationFunctionContexts.length;
- AggregationFunction[] aggregationFunctions = new AggregationFunction[numAggregationFunctions];
- for (int i = 0; i < numAggregationFunctions; i++) {
- aggregationFunctions[i] = aggregationFunctionContexts[i].getAggregationFunction();
- }
+ AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+ int numAggregationFunctions = aggregationFunctions.length;
// We use a CountDownLatch to track if all Futures are finished by the query timeout, and cancel the unfinished
// futures (try to interrupt the execution if it already started).
@@ -205,8 +199,7 @@ public class CombineGroupByOperator extends BaseOperator<IntermediateResultsBloc
new AggregationGroupByTrimmingService(aggregationFunctions, (int) _brokerRequest.getGroupBy().getTopN());
List<Map<String, Object>> trimmedResults =
aggregationGroupByTrimmingService.trimIntermediateResultsMap(resultsMap);
- IntermediateResultsBlock mergedBlock =
- new IntermediateResultsBlock(aggregationFunctionContexts, trimmedResults, true);
+ IntermediateResultsBlock mergedBlock = new IntermediateResultsBlock(aggregationFunctions, trimmedResults, true);
// Set the processing exceptions.
if (!mergedProcessingExceptions.isEmpty()) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
index 5b724f1..c45f8ba 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/CombineGroupByOrderByOperator.java
@@ -41,6 +41,8 @@ import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.core.query.exception.EarlyTerminationException;
@@ -98,7 +100,8 @@ public class CombineGroupByOrderByOperator extends BaseOperator<IntermediateResu
*/
@Override
protected IntermediateResultsBlock getNextBlock() {
- int numAggregationFunctions = _brokerRequest.getAggregationsInfoSize();
+ AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+ int numAggregationFunctions = aggregationFunctions.length;
int numGroupBy = _brokerRequest.getGroupBy().getExpressionsSize();
int numColumns = numGroupBy + numAggregationFunctions;
ConcurrentLinkedQueue<ProcessingException> mergedProcessingExceptions = new ConcurrentLinkedQueue<>();
@@ -137,8 +140,9 @@ public class CombineGroupByOrderByOperator extends BaseOperator<IntermediateResu
try {
if (_dataSchema == null) {
_dataSchema = intermediateResultsBlock.getDataSchema();
- _indexedTable = new ConcurrentIndexedTable(_dataSchema, _brokerRequest.getAggregationsInfo(),
- _brokerRequest.getOrderBy(), _indexedTableCapacity);
+ _indexedTable =
+ new ConcurrentIndexedTable(_dataSchema, aggregationFunctions, _brokerRequest.getOrderBy(),
+ _indexedTableCapacity);
}
} finally {
_initLock.unlock();
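In `CombineGroupByOrderByOperator`, the shared `ConcurrentIndexedTable` is created lazily by whichever worker reaches it first. The hunk shows a null check on `_dataSchema` inside a `try` whose `finally` releases `_initLock`. What follows is a minimal sketch of that lock-guarded lazy initialization. It assumes `_initLock` is a `java.util.concurrent.locks.Lock` acquired just before the `try`, since that line falls outside the hunk. It also uses hypothetical placeholders (`String`, `Object`) instead of `DataSchema` and `ConcurrentIndexedTable`.

```java
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class LazySharedTableSketch {
  private final Lock _initLock = new ReentrantLock();
  // Placeholders for DataSchema and ConcurrentIndexedTable (hypothetical types)
  private String _dataSchema;
  private Object _indexedTable;
  private int _numInitializations = 0;

  // Called by every worker thread; only the first caller creates the shared table.
  void initIfNeeded(String schemaFromBlock) {
    _initLock.lock();
    try {
      if (_dataSchema == null) {
        _dataSchema = schemaFromBlock;
        _indexedTable = new Object();
        _numInitializations++;
      }
    } finally {
      _initLock.unlock();
    }
  }

  int getNumInitializations() {
    _initLock.lock();
    try {
      return _numInitializations;
    } finally {
      _initLock.unlock();
    }
  }

  public static void main(String[] args) throws InterruptedException {
    LazySharedTableSketch sketch = new LazySharedTableSketch();
    Thread[] workers = new Thread[8];
    for (int i = 0; i < workers.length; i++) {
      String schema = "schema-from-block-" + i;
      workers[i] = new Thread(() -> sketch.initIfNeeded(schema));
      workers[i].start();
    }
    for (Thread worker : workers) {
      worker.join();
    }
    System.out.println(sketch.getNumInitializations()); // 1
  }
}
```

Releasing the lock in `finally` ensures that a failure while building the table cannot leave other workers blocked forever.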
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
index 882c8b2..63efe7c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/IntermediateResultsBlock.java
@@ -39,7 +39,7 @@ import org.apache.pinot.core.common.datatable.DataTableBuilder;
import org.apache.pinot.core.common.datatable.DataTableImplV2;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.data.table.Table;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.apache.pinot.spi.utils.ByteArray;
@@ -51,7 +51,7 @@ import org.apache.pinot.spi.utils.ByteArray;
public class IntermediateResultsBlock implements Block {
private DataSchema _dataSchema;
private Collection<Object[]> _selectionResult;
- private AggregationFunctionContext[] _aggregationFunctionContexts;
+ private AggregationFunction[] _aggregationFunctions;
private List<Object> _aggregationResult;
private AggregationGroupByResult _aggregationGroupByResult;
private List<Map<String, Object>> _combinedAggregationGroupByResult;
@@ -80,9 +80,9 @@ public class IntermediateResultsBlock implements Block {
* <p>For aggregation group-by, the result is a list of maps from group keys to aggregation values.
*/
@SuppressWarnings("unchecked")
- public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts, List aggregationResult,
+ public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions, List aggregationResult,
boolean isGroupBy) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ _aggregationFunctions = aggregationFunctions;
if (isGroupBy) {
_combinedAggregationGroupByResult = aggregationResult;
} else {
@@ -93,18 +93,18 @@ public class IntermediateResultsBlock implements Block {
/**
* Constructor for aggregation group-by result with {@link AggregationGroupByResult}.
*/
- public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts,
+ public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
@Nullable AggregationGroupByResult aggregationGroupByResults) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ _aggregationFunctions = aggregationFunctions;
_aggregationGroupByResult = aggregationGroupByResults;
}
/**
* Constructor for aggregation group-by order-by result with {@link AggregationGroupByResult}.
*/
- public IntermediateResultsBlock(AggregationFunctionContext[] aggregationFunctionContexts,
+ public IntermediateResultsBlock(AggregationFunction[] aggregationFunctions,
@Nullable AggregationGroupByResult aggregationGroupByResults, DataSchema dataSchema) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ _aggregationFunctions = aggregationFunctions;
_aggregationGroupByResult = aggregationGroupByResults;
_dataSchema = dataSchema;
}
@@ -134,7 +134,7 @@ public class IntermediateResultsBlock implements Block {
return _dataSchema;
}
- public void setDataSchema(@Nullable DataSchema dataSchema) {
+ public void setDataSchema(DataSchema dataSchema) {
_dataSchema = dataSchema;
}
@@ -143,17 +143,17 @@ public class IntermediateResultsBlock implements Block {
return _selectionResult;
}
- public void setSelectionResult(@Nullable Collection<Object[]> rowEventsSet) {
+ public void setSelectionResult(Collection<Object[]> rowEventsSet) {
_selectionResult = rowEventsSet;
}
@Nullable
- public AggregationFunctionContext[] getAggregationFunctionContexts() {
- return _aggregationFunctionContexts;
+ public AggregationFunction[] getAggregationFunctions() {
+ return _aggregationFunctions;
}
- public void setAggregationFunctionContexts(AggregationFunctionContext[] aggregationFunctionContexts) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ public void setAggregationFunctions(AggregationFunction[] aggregationFunctions) {
+ _aggregationFunctions = aggregationFunctions;
}
@Nullable
@@ -161,7 +161,7 @@ public class IntermediateResultsBlock implements Block {
return _aggregationResult;
}
- public void setAggregationResults(@Nullable List<Object> aggregationResults) {
+ public void setAggregationResults(List<Object> aggregationResults) {
_aggregationResult = aggregationResults;
}
@@ -175,7 +175,7 @@ public class IntermediateResultsBlock implements Block {
return _processingExceptions;
}
- public void setProcessingExceptions(@Nullable List<ProcessingException> processingExceptions) {
+ public void setProcessingExceptions(List<ProcessingException> processingExceptions) {
_processingExceptions = processingExceptions;
}
@@ -322,14 +322,14 @@ public class IntermediateResultsBlock implements Block {
private DataTable getAggregationResultDataTable()
throws Exception {
- // Extract each aggregation column name and type from aggregation function context.
- int numAggregationFunctions = _aggregationFunctionContexts.length;
+ // Extract result column name and type from each aggregation function
+ int numAggregationFunctions = _aggregationFunctions.length;
String[] columnNames = new String[numAggregationFunctions];
ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions];
for (int i = 0; i < numAggregationFunctions; i++) {
- AggregationFunctionContext aggregationFunctionContext = _aggregationFunctionContexts[i];
- columnNames[i] = aggregationFunctionContext.getAggregationColumnName();
- columnDataTypes[i] = aggregationFunctionContext.getAggregationFunction().getIntermediateResultColumnType();
+ AggregationFunction aggregationFunction = _aggregationFunctions[i];
+ columnNames[i] = aggregationFunction.getColumnName();
+ columnDataTypes[i] = aggregationFunction.getIntermediateResultColumnType();
}
// Build the data table.
@@ -364,11 +364,11 @@ public class IntermediateResultsBlock implements Block {
// Build the data table.
DataTableBuilder dataTableBuilder = new DataTableBuilder(new DataSchema(columnNames, columnDataTypes));
- int numAggregationFunctions = _aggregationFunctionContexts.length;
+ int numAggregationFunctions = _aggregationFunctions.length;
for (int i = 0; i < numAggregationFunctions; i++) {
dataTableBuilder.startRow();
- AggregationFunctionContext aggregationFunctionContext = _aggregationFunctionContexts[i];
- dataTableBuilder.setColumn(0, aggregationFunctionContext.getAggregationColumnName());
+ AggregationFunction aggregationFunction = _aggregationFunctions[i];
+ dataTableBuilder.setColumn(0, aggregationFunction.getColumnName());
dataTableBuilder.setColumn(1, _combinedAggregationGroupByResult.get(i));
dataTableBuilder.finishRow();
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
index 58ae7b0..9236530 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOperator.java
@@ -18,26 +18,27 @@
*/
package org.apache.pinot.core.operator.query;
-import org.apache.pinot.common.request.GroupBy;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.core.operator.ExecutionStatistics;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
/**
- * The <code>AggregationOperator</code> class provides the operator for aggregation group-by query on a single segment.
+ * The <code>AggregationGroupByOperator</code> class provides the operator for aggregation group-by query on a single
+ * segment.
*/
public class AggregationGroupByOperator extends BaseOperator<IntermediateResultsBlock> {
private static final String OPERATOR_NAME = "AggregationGroupByOperator";
- private final AggregationFunctionContext[] _functionContexts;
- private final GroupBy _groupBy;
+ private final AggregationFunction[] _aggregationFunctions;
+ private final TransformExpressionTree[] _groupByExpressions;
private final int _maxInitialResultHolderCapacity;
private final int _numGroupsLimit;
private final TransformOperator _transformOperator;
@@ -46,11 +47,11 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
private int _numDocsScanned = 0;
- public AggregationGroupByOperator(AggregationFunctionContext[] functionContexts, GroupBy groupBy,
- int maxInitialResultHolderCapacity, int numGroupsLimit, TransformOperator transformOperator, long numTotalDocs,
- boolean useStarTree) {
- _functionContexts = functionContexts;
- _groupBy = groupBy;
+ public AggregationGroupByOperator(AggregationFunction[] aggregationFunctions,
+ TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+ TransformOperator transformOperator, long numTotalDocs, boolean useStarTree) {
+ _aggregationFunctions = aggregationFunctions;
+ _groupByExpressions = groupByExpressions;
_maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
_numGroupsLimit = numGroupsLimit;
_transformOperator = transformOperator;
@@ -64,12 +65,12 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
GroupByExecutor groupByExecutor;
if (_useStarTree) {
groupByExecutor =
- new StarTreeGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
- _transformOperator);
+ new StarTreeGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+ _numGroupsLimit, _transformOperator);
} else {
groupByExecutor =
- new DefaultGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
- _transformOperator);
+ new DefaultGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+ _numGroupsLimit, _transformOperator);
}
TransformBlock transformBlock;
while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -78,7 +79,7 @@ public class AggregationGroupByOperator extends BaseOperator<IntermediateResults
}
// Build intermediate result block based on aggregation group-by result from the executor
- return new IntermediateResultsBlock(_functionContexts, groupByExecutor.getResult());
+ return new IntermediateResultsBlock(_aggregationFunctions, groupByExecutor.getResult());
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
index a2f840d..abd26c3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationGroupByOrderByOperator.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.operator.query;
-import org.apache.pinot.common.request.GroupBy;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.operator.BaseOperator;
@@ -26,35 +25,36 @@ import org.apache.pinot.core.operator.ExecutionStatistics;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByExecutor;
import org.apache.pinot.core.startree.executor.StarTreeGroupByExecutor;
/**
- * The <code>AggregationOperator</code> class provides the operator for aggregation group-by query on a single segment.
+ * The <code>AggregationGroupByOrderByOperator</code> class provides the operator for aggregation group-by query on a
+ * single segment.
*/
public class AggregationGroupByOrderByOperator extends BaseOperator<IntermediateResultsBlock> {
private static final String OPERATOR_NAME = "AggregationGroupByOrderByOperator";
private final DataSchema _dataSchema;
- private final AggregationFunctionContext[] _functionContexts;
- private final GroupBy _groupBy;
+ private final AggregationFunction[] _aggregationFunctions;
+ private final TransformExpressionTree[] _groupByExpressions;
private final int _maxInitialResultHolderCapacity;
private final int _numGroupsLimit;
private final TransformOperator _transformOperator;
private final long _numTotalDocs;
private final boolean _useStarTree;
- private int _numDocsScanned;
+ private int _numDocsScanned = 0;
- public AggregationGroupByOrderByOperator(AggregationFunctionContext[] functionContexts, GroupBy groupBy,
- int maxInitialResultHolderCapacity, int numGroupsLimit, TransformOperator transformOperator, long numTotalDocs,
- boolean useStarTree) {
- _functionContexts = functionContexts;
- _groupBy = groupBy;
+ public AggregationGroupByOrderByOperator(AggregationFunction[] aggregationFunctions,
+ TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+ TransformOperator transformOperator, long numTotalDocs, boolean useStarTree) {
+ _aggregationFunctions = aggregationFunctions;
+ _groupByExpressions = groupByExpressions;
_maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
_numGroupsLimit = numGroupsLimit;
_transformOperator = transformOperator;
@@ -62,30 +62,28 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
_useStarTree = useStarTree;
// NOTE: The indexedTable expects that the the data schema will have group by columns before aggregation columns
- int numColumns = groupBy.getExpressionsSize() + _functionContexts.length;
+ int numGroupByExpressions = groupByExpressions.length;
+ int numAggregationFunctions = aggregationFunctions.length;
+ int numColumns = numGroupByExpressions + numAggregationFunctions;
String[] columnNames = new String[numColumns];
DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
- // extract column names and data types for group by keys
- int index = 0;
- for (String groupByColumn : groupBy.getExpressions()) {
- columnNames[index] = groupByColumn;
- TransformExpressionTree expression = TransformExpressionTree.compileToExpressionTree(groupByColumn);
- columnDataTypes[index] =
- DataSchema.ColumnDataType.fromDataTypeSV(_transformOperator.getResultMetadata(expression).getDataType());
- index++;
+ // Extract column names and data types for group-by columns
+ for (int i = 0; i < numGroupByExpressions; i++) {
+ TransformExpressionTree groupByExpression = groupByExpressions[i];
+ columnNames[i] = groupByExpression.toString();
+ columnDataTypes[i] = DataSchema.ColumnDataType
+ .fromDataTypeSV(_transformOperator.getResultMetadata(groupByExpression).getDataType());
}
- // extract column names and data types for aggregations
- for (AggregationFunctionContext functionContext : functionContexts) {
- columnNames[index] = functionContext.getResultColumnName();
- columnDataTypes[index] = functionContext.getAggregationFunction().getIntermediateResultColumnType();
- index++;
+ // Extract column names and data types for aggregation functions
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunction aggregationFunction = aggregationFunctions[i];
+ int index = numGroupByExpressions + i;
+ columnNames[index] = aggregationFunction.getResultColumnName();
+ columnDataTypes[index] = aggregationFunction.getIntermediateResultColumnType();
}
- // TODO: We need to support putting order by columns in the data schema.
- // It is possible that the order by column is not one of the group by or aggregation columns
-
_dataSchema = new DataSchema(columnNames, columnDataTypes);
}
@@ -95,12 +93,12 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
GroupByExecutor groupByExecutor;
if (_useStarTree) {
groupByExecutor =
- new StarTreeGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
- _transformOperator);
+ new StarTreeGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+ _numGroupsLimit, _transformOperator);
} else {
groupByExecutor =
- new DefaultGroupByExecutor(_functionContexts, _groupBy, _maxInitialResultHolderCapacity, _numGroupsLimit,
- _transformOperator);
+ new DefaultGroupByExecutor(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
+ _numGroupsLimit, _transformOperator);
}
TransformBlock transformBlock;
while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -109,7 +107,7 @@ public class AggregationGroupByOrderByOperator extends BaseOperator<Intermediate
}
// Build intermediate result block based on aggregation group-by result from the executor
- return new IntermediateResultsBlock(_functionContexts, groupByExecutor.getResult(), _dataSchema);
+ return new IntermediateResultsBlock(_aggregationFunctions, groupByExecutor.getResult(), _dataSchema);
}
@Override
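The rewritten constructor of `AggregationGroupByOrderByOperator` lays out the data schema with all group-by columns first. The aggregation result columns follow at offset `numGroupByExpressions + i`, which is the order the `NOTE` says the indexed table expects. A minimal sketch of that layout with plain strings follows; the class, method, and column names are hypothetical.

```java
import java.util.Arrays;

public class GroupByOrderBySchemaSketch {
  // Group-by columns occupy indexes [0, numGroupBy); aggregation columns follow,
  // mirroring the `numGroupByExpressions + i` arithmetic in the constructor above.
  static String[] buildColumnNames(String[] groupByExpressions, String[] aggregationResultColumns) {
    int numGroupBy = groupByExpressions.length;
    String[] columnNames = new String[numGroupBy + aggregationResultColumns.length];
    for (int i = 0; i < numGroupBy; i++) {
      columnNames[i] = groupByExpressions[i];
    }
    for (int i = 0; i < aggregationResultColumns.length; i++) {
      columnNames[numGroupBy + i] = aggregationResultColumns[i];
    }
    return columnNames;
  }

  public static void main(String[] args) {
    String[] names = buildColumnNames(new String[]{"country", "city"}, new String[]{"count_star", "sum_price"});
    System.out.println(Arrays.toString(names)); // [country, city, count_star, sum_price]
  }
}
```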
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
index a20d39d..e29976a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/AggregationOperator.java
@@ -24,8 +24,8 @@ import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.query.aggregation.AggregationExecutor;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
@@ -35,16 +35,16 @@ import org.apache.pinot.core.startree.executor.StarTreeAggregationExecutor;
public class AggregationOperator extends BaseOperator<IntermediateResultsBlock> {
private static final String OPERATOR_NAME = "AggregationOperator";
- private final AggregationFunctionContext[] _functionContexts;
+ private final AggregationFunction[] _aggregationFunctions;
private final TransformOperator _transformOperator;
private final long _numTotalDocs;
private final boolean _useStarTree;
private int _numDocsScanned = 0;
- public AggregationOperator(AggregationFunctionContext[] functionContexts, TransformOperator transformOperator,
+ public AggregationOperator(AggregationFunction[] aggregationFunctions, TransformOperator transformOperator,
long numTotalDocs, boolean useStarTree) {
- _functionContexts = functionContexts;
+ _aggregationFunctions = aggregationFunctions;
_transformOperator = transformOperator;
_numTotalDocs = numTotalDocs;
_useStarTree = useStarTree;
@@ -55,9 +55,9 @@ public class AggregationOperator extends BaseOperator<IntermediateResultsBlock>
// Perform aggregation on all the transform blocks
AggregationExecutor aggregationExecutor;
if (_useStarTree) {
- aggregationExecutor = new StarTreeAggregationExecutor(_functionContexts);
+ aggregationExecutor = new StarTreeAggregationExecutor(_aggregationFunctions);
} else {
- aggregationExecutor = new DefaultAggregationExecutor(_functionContexts);
+ aggregationExecutor = new DefaultAggregationExecutor(_aggregationFunctions);
}
TransformBlock transformBlock;
while ((transformBlock = _transformOperator.nextBlock()) != null) {
@@ -66,7 +66,7 @@ public class AggregationOperator extends BaseOperator<IntermediateResultsBlock>
}
// Build intermediate result block based on aggregation result from the executor
- return new IntermediateResultsBlock(_functionContexts, aggregationExecutor.getResult(), false);
+ return new IntermediateResultsBlock(_aggregationFunctions, aggregationExecutor.getResult(), false);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
index 5c230da..20cd2e8 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java
@@ -21,14 +21,10 @@ package org.apache.pinot.core.operator.query;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.core.operator.ExecutionStatistics;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
-import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
-import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.customobject.MinMaxRangePair;
import org.apache.pinot.core.segment.index.readers.Dictionary;
@@ -47,56 +43,43 @@ import org.apache.pinot.core.segment.index.readers.Dictionary;
public class DictionaryBasedAggregationOperator extends BaseOperator<IntermediateResultsBlock> {
private static final String OPERATOR_NAME = "DictionaryBasedAggregationOperator";
- private final AggregationFunctionContext[] _aggregationFunctionContexts;
+ private final AggregationFunction[] _aggregationFunctions;
private final Map<String, Dictionary> _dictionaryMap;
- private final long _numTotalDocs;
+ private final int _numTotalDocs;
- /**
- * Constructor for the class.
- * @param aggregationFunctionContexts Aggregation function contexts.
- * @param numTotalDocs total raw docs from segmet metadata
- * @param dictionaryMap Map of column to its dictionary.
- */
- public DictionaryBasedAggregationOperator(AggregationFunctionContext[] aggregationFunctionContexts, long numTotalDocs,
- Map<String, Dictionary> dictionaryMap) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ public DictionaryBasedAggregationOperator(AggregationFunction[] aggregationFunctions,
+ Map<String, Dictionary> dictionaryMap, int numTotalDocs) {
+ _aggregationFunctions = aggregationFunctions;
_dictionaryMap = dictionaryMap;
_numTotalDocs = numTotalDocs;
}
@Override
protected IntermediateResultsBlock getNextBlock() {
- int numAggregationFunctions = _aggregationFunctionContexts.length;
+ int numAggregationFunctions = _aggregationFunctions.length;
List<Object> aggregationResults = new ArrayList<>(numAggregationFunctions);
-
- for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
- AggregationFunction function = aggregationFunctionContext.getAggregationFunction();
- AggregationFunctionType functionType = function.getType();
- String column = aggregationFunctionContext.getColumnName();
+ for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+ String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
Dictionary dictionary = _dictionaryMap.get(column);
- AggregationResultHolder resultHolder;
- switch (functionType) {
+ switch (aggregationFunction.getType()) {
case MAX:
- resultHolder = new DoubleAggregationResultHolder(dictionary.getDoubleValue(dictionary.length() - 1));
+ aggregationResults.add(dictionary.getDoubleValue(dictionary.length() - 1));
break;
case MIN:
- resultHolder = new DoubleAggregationResultHolder(dictionary.getDoubleValue(0));
+ aggregationResults.add(dictionary.getDoubleValue(0));
break;
case MINMAXRANGE:
- double max = dictionary.getDoubleValue(dictionary.length() - 1);
- double min = dictionary.getDoubleValue(0);
- resultHolder = new ObjectAggregationResultHolder();
- resultHolder.setValue(new MinMaxRangePair(min, max));
+ aggregationResults.add(
+ new MinMaxRangePair(dictionary.getDoubleValue(0), dictionary.getDoubleValue(dictionary.length() - 1)));
break;
default:
throw new IllegalStateException(
- "Dictionary based aggregation operator does not support function type: " + functionType);
+ "Dictionary based aggregation operator does not support function type: " + aggregationFunction.getType());
}
- aggregationResults.add(function.extractAggregationResult(resultHolder));
}
// Build intermediate result block based on aggregation result from the executor.
- return new IntermediateResultsBlock(_aggregationFunctionContexts, aggregationResults, false);
+ return new IntermediateResultsBlock(_aggregationFunctions, aggregationResults, false);
}
@Override
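The DictionaryBasedAggregationOperator change above relies on one property of a sorted dictionary. dictId 0 holds the smallest value and dictId length() - 1 holds the largest, so MIN, MAX and MINMAXRANGE need no document scan. Below is a minimal, self-contained sketch of that idea. SortedDoubleDictionary and its methods are hypothetical stand-ins that only mirror the two calls the operator makes (length() and getDoubleValue(dictId)); they are not Pinot's Dictionary interface. Unlike the operator, which stores a MinMaxRangePair, the sketch returns the range as a plain double.

```java
import java.util.Arrays;

// Hypothetical sketch: answering MIN / MAX / MINMAXRANGE from a sorted dictionary.
public class SortedDictionaryAggregation {

  // Hypothetical stand-in for a sorted dictionary: distinct values in ascending order.
  static final class SortedDoubleDictionary {
    private final double[] _values;

    SortedDoubleDictionary(double[] rawValues) {
      // Deduplicate and sort, so dictId 0 is the smallest value and dictId length() - 1 the largest.
      _values = Arrays.stream(rawValues).distinct().sorted().toArray();
    }

    int length() {
      return _values.length;
    }

    double getDoubleValue(int dictId) {
      return _values[dictId];
    }
  }

  // MIN is the first entry of a sorted dictionary (assumes a non-empty dictionary).
  static double min(SortedDoubleDictionary dictionary) {
    return dictionary.getDoubleValue(0);
  }

  // MAX is the last entry.
  static double max(SortedDoubleDictionary dictionary) {
    return dictionary.getDoubleValue(dictionary.length() - 1);
  }

  // MINMAXRANGE combines both lookups; no per-document values are read.
  static double minMaxRange(SortedDoubleDictionary dictionary) {
    return max(dictionary) - min(dictionary);
  }

  public static void main(String[] args) {
    SortedDoubleDictionary dictionary = new SortedDoubleDictionary(new double[]{7.0, 3.0, 9.5, 3.0, 7.0});
    // prints 3.0 9.5 6.5
    System.out.println(min(dictionary) + " " + max(dictionary) + " " + minMaxRange(dictionary));
  }
}
```

The shortcut is only correct when every document in the segment contributes to the aggregation, because the dictionary lists every value present in the column, not just the values of matching documents.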
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
index f999e62..514299b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/MetadataBasedAggregationOperator.java
@@ -27,7 +27,7 @@ import org.apache.pinot.core.common.DataSource;
import org.apache.pinot.core.operator.BaseOperator;
import org.apache.pinot.core.operator.ExecutionStatistics;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
@@ -37,41 +37,33 @@ import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
public class MetadataBasedAggregationOperator extends BaseOperator<IntermediateResultsBlock> {
private static final String OPERATOR_NAME = "MetadataBasedAggregationOperator";
- private final AggregationFunctionContext[] _aggregationFunctionContexts;
- private final Map<String, DataSource> _dataSourceMap;
+ private final AggregationFunction[] _aggregationFunctions;
private final SegmentMetadata _segmentMetadata;
+ private final Map<String, DataSource> _dataSourceMap;
- /**
- * Constructor for the class.
- *
- * @param aggregationFunctionContexts Aggregation function contexts.
- * @param segmentMetadata Segment metadata.
- * @param dataSourceMap Map of column to its data source.
- */
- public MetadataBasedAggregationOperator(AggregationFunctionContext[] aggregationFunctionContexts,
- SegmentMetadata segmentMetadata, Map<String, DataSource> dataSourceMap) {
- _aggregationFunctionContexts = aggregationFunctionContexts;
+ public MetadataBasedAggregationOperator(AggregationFunction[] aggregationFunctions, SegmentMetadata segmentMetadata,
+ Map<String, DataSource> dataSourceMap) {
+ _aggregationFunctions = aggregationFunctions;
+ _segmentMetadata = segmentMetadata;
// Datasource is currently not used, but will start getting used as we add support for aggregation
// functions other than count(*).
_dataSourceMap = dataSourceMap;
- _segmentMetadata = segmentMetadata;
}
@Override
protected IntermediateResultsBlock getNextBlock() {
- int numAggregationFunctions = _aggregationFunctionContexts.length;
+ int numAggregationFunctions = _aggregationFunctions.length;
List<Object> aggregationResults = new ArrayList<>(numAggregationFunctions);
long numTotalDocs = _segmentMetadata.getTotalDocs();
- for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
- AggregationFunctionType functionType = aggregationFunctionContext.getAggregationFunction().getType();
- Preconditions.checkState(functionType == AggregationFunctionType.COUNT,
- "Metadata based aggregation operator does not support function type: " + functionType);
+ for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+ Preconditions.checkState(aggregationFunction.getType() == AggregationFunctionType.COUNT,
+ "Metadata based aggregation operator does not support function type: " + aggregationFunction.getType());
aggregationResults.add(numTotalDocs);
}
// Build intermediate result block based on aggregation result from the executor.
- return new IntermediateResultsBlock(_aggregationFunctionContexts, aggregationResults, false);
+ return new IntermediateResultsBlock(_aggregationFunctions, aggregationResults, false);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
index 70fb1e2..a22de73 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByOrderByPlanNode.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.plan;
-import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.pinot.common.request.AggregationInfo;
@@ -29,7 +28,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.operator.query.AggregationGroupByOrderByOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -50,8 +49,9 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
private final int _maxInitialResultHolderCapacity;
private final int _numGroupsLimit;
private final List<AggregationInfo> _aggregationInfos;
- private final AggregationFunctionContext[] _functionContexts;
+ private final AggregationFunction[] _aggregationFunctions;
private final GroupBy _groupBy;
+ private final TransformExpressionTree[] _groupByExpressions;
private final TransformPlanNode _transformPlanNode;
private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
@@ -61,38 +61,51 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
_maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
_numGroupsLimit = numGroupsLimit;
_aggregationInfos = brokerRequest.getAggregationsInfo();
- _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+ _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
_groupBy = brokerRequest.getGroupBy();
-
- Set<TransformExpressionTree> expressionsToTransform =
- AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+ List<String> groupByExpressions = _groupBy.getExpressions();
+ int numGroupByExpressions = groupByExpressions.size();
+ _groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
+ for (int i = 0; i < numGroupByExpressions; i++) {
+ _groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressions.get(i));
+ }
List<StarTreeV2> starTrees = indexSegment.getStarTrees();
if (starTrees != null) {
if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
- Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
- for (AggregationInfo aggregationInfo : _aggregationInfos) {
- aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
- }
- Set<TransformExpressionTree> groupByExpressions = new HashSet<>();
- for (String expression : _groupBy.getExpressions()) {
- groupByExpressions.add(TransformExpressionTree.compileToExpressionTree(expression));
+ int numAggregationFunctions = _aggregationFunctions.length;
+ AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+ new AggregationFunctionColumnPair[numAggregationFunctions];
+ boolean hasUnsupportedAggregationFunction = false;
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunctionColumnPair aggregationFunctionColumnPair =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+ if (aggregationFunctionColumnPair != null) {
+ aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+ } else {
+ hasUnsupportedAggregationFunction = true;
+ break;
+ }
}
- FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
- for (StarTreeV2 starTreeV2 : starTrees) {
- if (StarTreeUtils
- .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, groupByExpressions,
- rootFilterNode)) {
- _transformPlanNode = null;
- _starTreeTransformPlanNode =
- new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, groupByExpressions,
- rootFilterNode, brokerRequest.getDebugOptions());
- return;
+ if (!hasUnsupportedAggregationFunction) {
+ FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+ for (StarTreeV2 starTreeV2 : starTrees) {
+ if (StarTreeUtils
+ .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, _groupByExpressions,
+ rootFilterNode)) {
+ _transformPlanNode = null;
+ _starTreeTransformPlanNode =
+ new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, _groupByExpressions,
+ rootFilterNode, brokerRequest.getDebugOptions());
+ return;
+ }
}
}
}
}
+ Set<TransformExpressionTree> expressionsToTransform =
+ AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, _groupByExpressions);
_transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
_starTreeTransformPlanNode = null;
}
@@ -102,19 +115,19 @@ public class AggregationGroupByOrderByPlanNode implements PlanNode {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
if (_transformPlanNode != null) {
// Do not use star-tree
- return new AggregationGroupByOrderByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
- _numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
+ return new AggregationGroupByOrderByOperator(_aggregationFunctions, _groupByExpressions,
+ _maxInitialResultHolderCapacity, _numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
} else {
// Use star-tree
- return new AggregationGroupByOrderByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
- _numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
+ return new AggregationGroupByOrderByOperator(_aggregationFunctions, _groupByExpressions,
+ _maxInitialResultHolderCapacity, _numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
}
}
@Override
public void showTree(String prefix) {
- LOGGER.debug(prefix + "Aggregation Group-by Plan Node:");
- LOGGER.debug(prefix + "Operator: AggregationGroupByOperator");
+ LOGGER.debug(prefix + "Aggregation Group-by Order-by Plan Node:");
+ LOGGER.debug(prefix + "Operator: AggregationGroupByOrderByOperator");
LOGGER.debug(prefix + "Argument 0: IndexSegment - " + _indexSegment.getSegmentName());
LOGGER.debug(prefix + "Argument 1: Aggregations - " + _aggregationInfos);
LOGGER.debug(prefix + "Argument 2: GroupBy - " + _groupBy);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
index 44b456c..543903e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationGroupByPlanNode.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.plan;
-import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.pinot.common.request.AggregationInfo;
@@ -29,7 +28,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -50,8 +49,9 @@ public class AggregationGroupByPlanNode implements PlanNode {
private final int _maxInitialResultHolderCapacity;
private final int _numGroupsLimit;
private final List<AggregationInfo> _aggregationInfos;
- private final AggregationFunctionContext[] _functionContexts;
+ private final AggregationFunction[] _aggregationFunctions;
private final GroupBy _groupBy;
+ private final TransformExpressionTree[] _groupByExpressions;
private final TransformPlanNode _transformPlanNode;
private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
@@ -61,37 +61,51 @@ public class AggregationGroupByPlanNode implements PlanNode {
_maxInitialResultHolderCapacity = maxInitialResultHolderCapacity;
_numGroupsLimit = numGroupsLimit;
_aggregationInfos = brokerRequest.getAggregationsInfo();
- _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+ _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
_groupBy = brokerRequest.getGroupBy();
+ List<String> groupByExpressions = _groupBy.getExpressions();
+ int numGroupByExpressions = groupByExpressions.size();
+ _groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
+ for (int i = 0; i < numGroupByExpressions; i++) {
+ _groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressions.get(i));
+ }
List<StarTreeV2> starTrees = indexSegment.getStarTrees();
if (starTrees != null) {
if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
- Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
- for (AggregationInfo aggregationInfo : _aggregationInfos) {
- aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
- }
- Set<TransformExpressionTree> groupByExpressions = new HashSet<>();
- for (String expression : _groupBy.getExpressions()) {
- groupByExpressions.add(TransformExpressionTree.compileToExpressionTree(expression));
+ int numAggregationFunctions = _aggregationFunctions.length;
+ AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+ new AggregationFunctionColumnPair[numAggregationFunctions];
+ boolean hasUnsupportedAggregationFunction = false;
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunctionColumnPair aggregationFunctionColumnPair =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+ if (aggregationFunctionColumnPair != null) {
+ aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+ } else {
+ hasUnsupportedAggregationFunction = true;
+ break;
+ }
}
- FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
- for (StarTreeV2 starTreeV2 : starTrees) {
- if (StarTreeUtils
- .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, groupByExpressions,
- rootFilterNode)) {
- _transformPlanNode = null;
- _starTreeTransformPlanNode =
- new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, groupByExpressions,
- rootFilterNode, brokerRequest.getDebugOptions());
- return;
+ if (!hasUnsupportedAggregationFunction) {
+ FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+ for (StarTreeV2 starTreeV2 : starTrees) {
+ if (StarTreeUtils
+ .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, _groupByExpressions,
+ rootFilterNode)) {
+ _transformPlanNode = null;
+ _starTreeTransformPlanNode =
+ new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, _groupByExpressions,
+ rootFilterNode, brokerRequest.getDebugOptions());
+ return;
+ }
}
}
}
}
Set<TransformExpressionTree> expressionsToTransform =
- AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+ AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, _groupByExpressions);
_transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
_starTreeTransformPlanNode = null;
}
@@ -101,11 +115,11 @@ public class AggregationGroupByPlanNode implements PlanNode {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
if (_transformPlanNode != null) {
// Do not use star-tree
- return new AggregationGroupByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
+ return new AggregationGroupByOperator(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
_numGroupsLimit, _transformPlanNode.run(), numTotalDocs, false);
} else {
// Use star-tree
- return new AggregationGroupByOperator(_functionContexts, _groupBy, _maxInitialResultHolderCapacity,
+ return new AggregationGroupByOperator(_aggregationFunctions, _groupByExpressions, _maxInitialResultHolderCapacity,
_numGroupsLimit, _starTreeTransformPlanNode.run(), numTotalDocs, true);
}
}
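Both group-by plan nodes above now convert every AggregationFunction into an AggregationFunctionColumnPair up front, and they skip star-tree matching entirely if any conversion returns null. Here is a rough sketch of that all-or-nothing step. It assumes aggregations arrive as strings like "SUM(metric)". FunctionColumnPair, parse() and the SUPPORTED set are hypothetical, and SUPPORTED is illustrative only, not Pinot's actual list of star-tree-compatible functions.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Set;

// Hypothetical sketch: all-or-nothing conversion of aggregations before star-tree matching.
public class StarTreeEligibility {

  // Hypothetical (function, column) pair; not Pinot's AggregationFunctionColumnPair.
  static final class FunctionColumnPair {
    final String functionType;
    final String column;

    FunctionColumnPair(String functionType, String column) {
      this.functionType = functionType;
      this.column = column;
    }

    @Override
    public String toString() {
      return functionType + ":" + column;
    }
  }

  // Illustrative set only.
  static final Set<String> SUPPORTED = Set.of("COUNT", "SUM", "MIN", "MAX");

  // Parses well-formed strings such as "SUM(metric)"; returns null for unsupported functions.
  static FunctionColumnPair parse(String aggregation) {
    int open = aggregation.indexOf('(');
    String type = aggregation.substring(0, open).toUpperCase();
    String column = aggregation.substring(open + 1, aggregation.length() - 1);
    return SUPPORTED.contains(type) ? new FunctionColumnPair(type, column) : null;
  }

  // Returns every pair, or null as soon as one aggregation cannot be converted.
  static FunctionColumnPair[] toPairs(List<String> aggregations) {
    FunctionColumnPair[] pairs = new FunctionColumnPair[aggregations.size()];
    for (int i = 0; i < pairs.length; i++) {
      FunctionColumnPair pair = parse(aggregations.get(i));
      if (pair == null) {
        return null; // one unsupported function rules out every star-tree
      }
      pairs[i] = pair;
    }
    return pairs;
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(toPairs(List.of("SUM(metric)", "max(metric)"))));
    System.out.println(toPairs(List.of("SUM(metric)", "DISTINCTCOUNT(dim)")) == null);
  }
}
```

Returning null folds the diff's hasUnsupportedAggregationFunction flag and break into a single exit. Either way, only a fully converted array ever reaches isFitForStarTree.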
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index 8932a37..ce0217b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.plan;
-import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.pinot.common.request.AggregationInfo;
@@ -28,7 +27,7 @@ import org.apache.pinot.common.utils.request.FilterQueryTree;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.operator.query.AggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.plan.StarTreeTransformPlanNode;
@@ -47,38 +46,50 @@ public class AggregationPlanNode implements PlanNode {
private final IndexSegment _indexSegment;
private final List<AggregationInfo> _aggregationInfos;
- private final AggregationFunctionContext[] _functionContexts;
+ private final AggregationFunction[] _aggregationFunctions;
private final TransformPlanNode _transformPlanNode;
private final StarTreeTransformPlanNode _starTreeTransformPlanNode;
public AggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
_indexSegment = indexSegment;
_aggregationInfos = brokerRequest.getAggregationsInfo();
- _functionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
+ _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
List<StarTreeV2> starTrees = indexSegment.getStarTrees();
if (starTrees != null) {
if (!StarTreeUtils.isStarTreeDisabled(brokerRequest)) {
- Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs = new HashSet<>();
- for (AggregationInfo aggregationInfo : _aggregationInfos) {
- aggregationFunctionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
+ int numAggregationFunctions = _aggregationFunctions.length;
+ AggregationFunctionColumnPair[] aggregationFunctionColumnPairs =
+ new AggregationFunctionColumnPair[numAggregationFunctions];
+ boolean hasUnsupportedAggregationFunction = false;
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunctionColumnPair aggregationFunctionColumnPair =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(_aggregationFunctions[i]);
+ if (aggregationFunctionColumnPair != null) {
+ aggregationFunctionColumnPairs[i] = aggregationFunctionColumnPair;
+ } else {
+ hasUnsupportedAggregationFunction = true;
+ break;
+ }
}
- FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
- for (StarTreeV2 starTreeV2 : starTrees) {
- if (StarTreeUtils
- .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null, rootFilterNode)) {
- _transformPlanNode = null;
- _starTreeTransformPlanNode =
- new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, null, rootFilterNode,
- brokerRequest.getDebugOptions());
- return;
+ if (!hasUnsupportedAggregationFunction) {
+ FilterQueryTree rootFilterNode = RequestUtils.generateFilterQueryTree(brokerRequest);
+ for (StarTreeV2 starTreeV2 : starTrees) {
+ if (StarTreeUtils
+ .isFitForStarTree(starTreeV2.getMetadata(), aggregationFunctionColumnPairs, null, rootFilterNode)) {
+ _transformPlanNode = null;
+ _starTreeTransformPlanNode =
+ new StarTreeTransformPlanNode(starTreeV2, aggregationFunctionColumnPairs, null, rootFilterNode,
+ brokerRequest.getDebugOptions());
+ return;
+ }
}
}
}
}
Set<TransformExpressionTree> expressionsToTransform =
- AggregationFunctionUtils.collectExpressionsToTransform(brokerRequest, _functionContexts);
+ AggregationFunctionUtils.collectExpressionsToTransform(_aggregationFunctions, null);
_transformPlanNode = new TransformPlanNode(_indexSegment, brokerRequest, expressionsToTransform);
_starTreeTransformPlanNode = null;
}
@@ -88,10 +99,10 @@ public class AggregationPlanNode implements PlanNode {
int numTotalDocs = _indexSegment.getSegmentMetadata().getTotalDocs();
if (_transformPlanNode != null) {
// Do not use star-tree
- return new AggregationOperator(_functionContexts, _transformPlanNode.run(), numTotalDocs, false);
+ return new AggregationOperator(_aggregationFunctions, _transformPlanNode.run(), numTotalDocs, false);
} else {
// Use star-tree
- return new AggregationOperator(_functionContexts, _starTreeTransformPlanNode.run(), numTotalDocs, true);
+ return new AggregationOperator(_aggregationFunctions, _starTreeTransformPlanNode.run(), numTotalDocs, true);
}
}
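AggregationPlanNode now calls collectExpressionsToTransform(_aggregationFunctions, null), passing null because an aggregation-only query has no group-by expressions. The following is a minimal sketch of what such a helper plausibly does. It assumes the helper takes the union of every function's input expressions plus the optional group-by expressions. Strings stand in for TransformExpressionTree to keep it short, and collectExpressions is a hypothetical name, not the Pinot method.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper; strings stand in for TransformExpressionTree.
public class CollectExpressions {

  // Union of all aggregation inputs plus the optional group-by expressions.
  // groupByExpressions may be null, as in the aggregation-only plan node.
  static Set<String> collectExpressions(List<List<String>> aggregationInputs, String[] groupByExpressions) {
    Set<String> expressions = new LinkedHashSet<>(); // deduplicates, keeps first-seen order
    for (List<String> inputs : aggregationInputs) {
      expressions.addAll(inputs);
    }
    if (groupByExpressions != null) {
      expressions.addAll(Arrays.asList(groupByExpressions));
    }
    return expressions;
  }

  public static void main(String[] args) {
    // e.g. SUM(price), MAX(price), AVG(add(a,b))
    List<List<String>> inputs = List.of(List.of("price"), List.of("price"), List.of("add(a,b)"));
    System.out.println(collectExpressions(inputs, null));
    System.out.println(collectExpressions(inputs, new String[]{"country"}));
  }
}
```

Using a set means an expression shared by several functions (price above) is transformed once.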
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
index 4bf368d..561f3eb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/DictionaryBasedAggregationPlanNode.java
@@ -21,10 +21,11 @@ package org.apache.pinot.core.plan;
import java.util.HashMap;
import java.util.Map;
import org.apache.pinot.common.request.BrokerRequest;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.operator.query.DictionaryBasedAggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.segment.index.readers.Dictionary;
import org.slf4j.Logger;
@@ -37,9 +38,9 @@ import org.slf4j.LoggerFactory;
public class DictionaryBasedAggregationPlanNode implements PlanNode {
private static final Logger LOGGER = LoggerFactory.getLogger(DictionaryBasedAggregationPlanNode.class);
+ private final IndexSegment _indexSegment;
+ private final AggregationFunction[] _aggregationFunctions;
private final Map<String, Dictionary> _dictionaryMap;
- private final AggregationFunctionContext[] _aggregationFunctionContexts;
- private IndexSegment _indexSegment;
/**
* Constructor for the class.
@@ -49,22 +50,18 @@ public class DictionaryBasedAggregationPlanNode implements PlanNode {
*/
public DictionaryBasedAggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
_indexSegment = indexSegment;
+ _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
_dictionaryMap = new HashMap<>();
-
- _aggregationFunctionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(brokerRequest);
-
- for (AggregationFunctionContext aggregationFunctionContext : _aggregationFunctionContexts) {
- String column = aggregationFunctionContext.getColumnName();
- if (!_dictionaryMap.containsKey(column)) {
- _dictionaryMap.put(column, _indexSegment.getDataSource(column).getDictionary());
- }
+ for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+ String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
+ _dictionaryMap.computeIfAbsent(column, k -> _indexSegment.getDataSource(k).getDictionary());
}
}
@Override
public Operator run() {
- return new DictionaryBasedAggregationOperator(_aggregationFunctionContexts,
- _indexSegment.getSegmentMetadata().getTotalDocs(), _dictionaryMap);
+ return new DictionaryBasedAggregationOperator(_aggregationFunctions, _dictionaryMap,
+ _indexSegment.getSegmentMetadata().getTotalDocs());
}
@Override
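DictionaryBasedAggregationPlanNode replaces the containsKey/put pair with Map.computeIfAbsent, so each distinct column's dictionary is fetched once regardless of how many aggregation functions read it. The sketch below shows the pattern. The loader is a hypothetical placeholder for indexSegment.getDataSource(column).getDictionary(), and strings stand in for dictionaries.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Sketch of per-column caching with computeIfAbsent; the loader is hypothetical.
public class PerColumnDictionaryMap {

  // Builds one entry per distinct column.
  static <D> Map<String, D> buildDictionaryMap(List<String> columns, Function<String, D> loader) {
    Map<String, D> dictionaryMap = new HashMap<>();
    for (String column : columns) {
      // The loader runs only when the column has no mapping yet.
      dictionaryMap.computeIfAbsent(column, loader);
    }
    return dictionaryMap;
  }

  public static void main(String[] args) {
    AtomicInteger loads = new AtomicInteger();
    // MIN(price), MAX(price), MIN(qty): three functions, two distinct columns
    Map<String, String> dictionaries = buildDictionaryMap(List.of("price", "price", "qty"), column -> {
      loads.incrementAndGet();
      return "dictionary-of-" + column;
    });
    // prints 2 entries, 2 loads
    System.out.println(dictionaries.size() + " entries, " + loads.get() + " loads");
  }
}
```

One behavioral difference from the removed code: if the mapping function returns null, computeIfAbsent records no entry, whereas the old containsKey/put sequence stored the null.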
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
index 3ac5e5d..8125088 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/MetadataBasedAggregationPlanNode.java
@@ -24,13 +24,13 @@ import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.DataSource;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.operator.query.MetadataBasedAggregationOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
-import org.apache.pinot.core.segment.index.metadata.SegmentMetadata;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -43,7 +43,8 @@ public class MetadataBasedAggregationPlanNode implements PlanNode {
private final IndexSegment _indexSegment;
private final List<AggregationInfo> _aggregationInfos;
- private final BrokerRequest _brokerRequest;
+ private final AggregationFunction[] _aggregationFunctions;
+ private final Map<String, DataSource> _dataSourceMap;
/**
* Constructor for the class.
@@ -53,27 +54,21 @@ public class MetadataBasedAggregationPlanNode implements PlanNode {
*/
public MetadataBasedAggregationPlanNode(IndexSegment indexSegment, BrokerRequest brokerRequest) {
_indexSegment = indexSegment;
- _brokerRequest = brokerRequest;
_aggregationInfos = brokerRequest.getAggregationsInfo();
+ _aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+ _dataSourceMap = new HashMap<>();
+ for (AggregationFunction aggregationFunction : _aggregationFunctions) {
+ if (aggregationFunction.getType() != AggregationFunctionType.COUNT) {
+ String column = ((TransformExpressionTree) aggregationFunction.getInputExpressions().get(0)).getValue();
+ _dataSourceMap.computeIfAbsent(column, _indexSegment::getDataSource);
+ }
+ }
}
@Override
public Operator run() {
- SegmentMetadata segmentMetadata = _indexSegment.getSegmentMetadata();
- AggregationFunctionContext[] aggregationFunctionContexts =
- AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
-
- Map<String, DataSource> dataSourceMap = new HashMap<>();
- for (AggregationFunctionContext aggregationFunctionContext : aggregationFunctionContexts) {
- if (aggregationFunctionContext.getAggregationFunction().getType() != AggregationFunctionType.COUNT) {
- String column = aggregationFunctionContext.getColumnName();
- if (!dataSourceMap.containsKey(column)) {
- dataSourceMap.put(column, _indexSegment.getDataSource(column));
- }
- }
- }
-
- return new MetadataBasedAggregationOperator(aggregationFunctionContexts, segmentMetadata, dataSourceMap);
+ return new MetadataBasedAggregationOperator(_aggregationFunctions, _indexSegment.getSegmentMetadata(),
+ _dataSourceMap);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
index 182f79b..2103115 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/maker/InstancePlanMakerImplV2.java
@@ -26,7 +26,6 @@ import java.util.concurrent.ExecutorService;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
-import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.data.manager.SegmentDataManager;
import org.apache.pinot.core.indexsegment.IndexSegment;
import org.apache.pinot.core.plan.AggregationGroupByOrderByPlanNode;
@@ -205,9 +204,9 @@ public class InstancePlanMakerImplV2 implements PlanMaker {
AggregationFunctionType.getAggregationFunctionType(aggregationInfo.getAggregationType());
if (functionType
.isOfType(AggregationFunctionType.MIN, AggregationFunctionType.MAX, AggregationFunctionType.MINMAXRANGE)) {
- String expression = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo).get(0);
- if (TransformExpressionTree.compileToExpressionTree(expression).isColumn()) {
- Dictionary dictionary = indexSegment.getDataSource(expression).getDictionary();
+ String column = AggregationFunctionUtils.getArguments(aggregationInfo).get(0);
+ if (indexSegment.getColumnNames().contains(column)) {
+ Dictionary dictionary = indexSegment.getDataSource(column).getDictionary();
return dictionary != null && dictionary.isSorted();
}
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java
deleted file mode 100644
index 9d1027f..0000000
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/AggregationFunctionContext.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.core.query.aggregation;
-
-import com.google.common.base.Preconditions;
-import java.util.List;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
-
-
-/**
- * This class caches miscellaneous data to perform efficient aggregation.
- *
- * TODO: Remove this class, as it no longer provides any value after aggregation functions now store
- * their arguments.
- */
-public class AggregationFunctionContext {
- private final AggregationFunction _aggregationFunction;
- private final List<String> _expressions;
- private final String columnName;
-
- public AggregationFunctionContext(AggregationFunction aggregationFunction, List<String> expressions) {
- Preconditions.checkArgument(expressions.size() >= 1, "Aggregation functions require at least one argument.");
- _aggregationFunction = aggregationFunction;
- _expressions = expressions;
- columnName = AggregationFunctionUtils.concatArgs(expressions);
- }
-
- /**
- * Returns the aggregation function.
- */
- public AggregationFunction getAggregationFunction() {
- return _aggregationFunction;
- }
-
- /**
- * Returns the arguments for the aggregation function.
- *
- * @return List of Strings containing the arguments for the aggregation function.
- */
- public List<String> getExpressions() {
- return _expressions;
- }
-
- /**
- * Returns the column for aggregation function.
- *
- * @return Aggregation Column (could be column name or UDF expression).
- */
- public String getColumnName() {
- return columnName;
- }
-
- /**
- * Returns the aggregation column name for the results.
- * <p>E.g. AVG(foo) -> avg_foo
- */
- public String getAggregationColumnName() {
- return _aggregationFunction.getColumnName();
- }
-
- /**
- * Returns the aggregation column name for the result table.
- * <p>E.g. AVGMV(foo) -> avgMV(foo)
- */
- public String getResultColumnName() {
- return _aggregationFunction.getResultColumnName();
- }
-}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
index 4ab0754..706bf1d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DefaultAggregationExecutor.java
@@ -19,70 +19,42 @@
package org.apache.pinot.core.query.aggregation;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.transform.TransformExpressionTree;
-import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
public class DefaultAggregationExecutor implements AggregationExecutor {
- protected final int _numFunctions;
- protected final AggregationFunction[] _functions;
- protected final AggregationResultHolder[] _resultHolders;
- protected final TransformExpressionTree[][] _expressions;
-
- public DefaultAggregationExecutor(AggregationFunctionContext[] functionContexts) {
- _numFunctions = functionContexts.length;
- _functions = new AggregationFunction[_numFunctions];
- _resultHolders = new AggregationResultHolder[_numFunctions];
-
- _expressions = new TransformExpressionTree[_numFunctions][];
- for (int i = 0; i < _numFunctions; i++) {
- AggregationFunction function = functionContexts[i].getAggregationFunction();
- _functions[i] = function;
- _resultHolders[i] = _functions[i].createAggregationResultHolder();
-
- if (function.getType() != AggregationFunctionType.COUNT) {
- // count(*) does not have a column so handle rest of the aggregate
- // functions -- sum, min, max etc
-
- List<TransformExpressionTree> inputExpressionsList = function.getInputExpressions();
- _expressions[i] = inputExpressionsList.toArray(new TransformExpressionTree[0]);
- }
+ protected final AggregationFunction[] _aggregationFunctions;
+ protected final AggregationResultHolder[] _aggregationResultHolders;
+
+ public DefaultAggregationExecutor(AggregationFunction[] aggregationFunctions) {
+ _aggregationFunctions = aggregationFunctions;
+ int numAggregationFunctions = aggregationFunctions.length;
+ _aggregationResultHolders = new AggregationResultHolder[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ _aggregationResultHolders[i] = aggregationFunctions[i].createAggregationResultHolder();
}
}
@Override
public void aggregate(TransformBlock transformBlock) {
+ int numAggregationFunctions = _aggregationFunctions.length;
int length = transformBlock.getNumDocs();
- for (int i = 0; i < _numFunctions; i++) {
- AggregationFunction function = _functions[i];
- AggregationResultHolder resultHolder = _resultHolders[i];
- if (function.getType() == AggregationFunctionType.COUNT) {
- // handle count(*) function
- function.aggregate(length, resultHolder, Collections.emptyMap());
- } else {
- // handle rest of the aggregate functions -- sum, min, max etc
- Map<String, BlockValSet> blockValSetMap = new HashMap<>();
-
- for (int j = 0; j < _expressions[i].length; j++) {
- blockValSetMap.put(_expressions[i][j].toString(), transformBlock.getBlockValueSet(_expressions[i][j]));
- }
- function.aggregate(length, resultHolder, blockValSetMap);
- }
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunction aggregationFunction = _aggregationFunctions[i];
+ aggregationFunction.aggregate(length, _aggregationResultHolders[i],
+ AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock));
}
}
@Override
public List<Object> getResult() {
- List<Object> aggregationResults = new ArrayList<>(_numFunctions);
- for (int i = 0; i < _numFunctions; i++) {
- aggregationResults.add(_functions[i].extractAggregationResult(_resultHolders[i]));
+ int numFunctions = _aggregationFunctions.length;
+ List<Object> aggregationResults = new ArrayList<>(numFunctions);
+ for (int i = 0; i < numFunctions; i++) {
+ aggregationResults.add(_aggregationFunctions[i].extractAggregationResult(_aggregationResultHolders[i]));
}
return aggregationResults;
}
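The rewritten DefaultAggregationExecutor above drops the COUNT special case and no longer rebuilds String keys for every block. Each function receives a map built from its own getInputExpressions(), and AggregationFunctionUtils.getBlockValSetMap handles the zero-, one- and many-input cases. The following is a minimal sketch of that flow under stated assumptions. AggFunction, Block, count() and sum() are hypothetical simplified stand-ins, not Pinot classes. Values are plain double[] instead of BlockValSet, and String keys stand in for TransformExpressionTree to keep the sketch short.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of the executor flow; none of these types are Pinot classes.
final class ExecutorSketch {

  // Stand-in for AggregationFunction (String keys replace TransformExpressionTree).
  interface AggFunction {
    List<String> inputExpressions();
    double[] createResultHolder();
    void aggregate(int length, double[] holder, Map<String, double[]> blockValSetMap);
    double extractResult(double[] holder);
  }

  // Stand-in for TransformBlock: number of docs plus column name -> values.
  static final class Block {
    final int numDocs;
    final Map<String, double[]> columns;

    Block(int numDocs, Map<String, double[]> columns) {
      this.numDocs = numDocs;
      this.columns = columns;
    }
  }

  // Same 0 / 1 / n split as getBlockValSetMap: only the multi-input case allocates a HashMap.
  static Map<String, double[]> blockValSetMap(AggFunction function, Block block) {
    List<String> expressions = function.inputExpressions();
    if (expressions.isEmpty()) {
      return Collections.emptyMap(); // e.g. COUNT(*) reads no input column
    }
    if (expressions.size() == 1) {
      String expression = expressions.get(0);
      return Collections.singletonMap(expression, block.columns.get(expression));
    }
    Map<String, double[]> map = new HashMap<>();
    for (String expression : expressions) {
      map.put(expression, block.columns.get(expression));
    }
    return map;
  }

  private final AggFunction[] functions; // built once at planning time and passed in
  private final double[][] holders;

  ExecutorSketch(AggFunction[] functions) {
    this.functions = functions;
    this.holders = new double[functions.length][];
    for (int i = 0; i < functions.length; i++) {
      holders[i] = functions[i].createResultHolder();
    }
  }

  // No COUNT special case: every function gets a map built from its own inputs.
  void aggregate(Block block) {
    for (int i = 0; i < functions.length; i++) {
      functions[i].aggregate(block.numDocs, holders[i], blockValSetMap(functions[i], block));
    }
  }

  List<Double> getResult() {
    List<Double> results = new ArrayList<>(functions.length);
    for (int i = 0; i < functions.length; i++) {
      results.add(functions[i].extractResult(holders[i]));
    }
    return results;
  }

  // Hypothetical COUNT(*): no inputs, just counts documents.
  static AggFunction count() {
    return new AggFunction() {
      public List<String> inputExpressions() { return Collections.emptyList(); }
      public double[] createResultHolder() { return new double[1]; }
      public void aggregate(int length, double[] holder, Map<String, double[]> m) { holder[0] += length; }
      public double extractResult(double[] holder) { return holder[0]; }
    };
  }

  // Hypothetical SUM(column): reads its single input from the map.
  static AggFunction sum(String column) {
    return new AggFunction() {
      public List<String> inputExpressions() { return Collections.singletonList(column); }
      public double[] createResultHolder() { return new double[1]; }
      public void aggregate(int length, double[] holder, Map<String, double[]> m) {
        double[] values = m.get(column);
        for (int i = 0; i < length; i++) {
          holder[0] += values[i];
        }
      }
      public double extractResult(double[] holder) { return holder[0]; }
    };
  }
}
```

Because COUNT simply declares no inputs, it goes through the same code path as every other function and receives an empty map.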
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
index de51993..d294218 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/DistinctTable.java
@@ -21,7 +21,6 @@ package org.apache.pinot.core.query.aggregation;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
@@ -36,6 +35,7 @@ import org.apache.pinot.core.common.datatable.DataTableFactory;
import org.apache.pinot.core.data.table.BaseTable;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.core.data.table.Record;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.spi.utils.ByteArray;
@@ -59,7 +59,7 @@ public class DistinctTable extends BaseTable {
// NOTE: The passed in capacity is calculated based on the LIMIT in the query as Math.max(limit * 5, 5000). When
// LIMIT is smaller than (64 * 1024 * 0.75 (load factor) / 5 = 9830), then it is guaranteed that no resize is
// required.
- super(dataSchema, Collections.emptyList(), orderBy, capacity);
+ super(dataSchema, new AggregationFunction[0], orderBy, capacity);
int initialCapacity = Math.min(MAX_INITIAL_CAPACITY, HashUtil.getHashMapCapacity(capacity));
_uniqueRecordsSet = new HashSet<>(initialCapacity);
_noMoreNewRecords = false;
@@ -162,8 +162,8 @@ public class DistinctTable extends BaseTable {
// information to pass to super class so just pass null, empty lists
// and the broker will set the correct information before merging the
// data tables.
- super(new DataSchema(new String[0], new DataSchema.ColumnDataType[0]), Collections.emptyList(), new ArrayList<>(),
- 0);
+ super(new DataSchema(new String[0], new DataSchema.ColumnDataType[0]), new AggregationFunction[0],
+ new ArrayList<>(), 0);
DataTable dataTable = DataTableFactory.getDataTable(byteBuffer);
_dataSchema = dataTable.getDataSchema();
_uniqueRecordsSet = new HashSet<>();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
index 08e4868..362309c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunction.java
@@ -76,21 +76,22 @@ public interface AggregationFunction<IntermediateResult, FinalResult extends Com
/**
* Performs aggregation on the given block value sets (aggregation only).
*/
- void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap);
+ void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap);
/**
* Performs aggregation on the given group key array and block value sets (aggregation group-by on single-value
* columns).
*/
void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSets);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap);
/**
* Performs aggregation on the given group keys array and block value sets (aggregation group-by on multi-value
* columns).
*/
void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSets);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap);
/**
* Extracts the intermediate result from the aggregation result holder (aggregation only).
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 90f99f6..f21f7fb 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -19,7 +19,6 @@
package org.apache.pinot.core.query.aggregation.function;
import com.google.common.base.Preconditions;
-import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.pinot.common.function.AggregationFunctionType;
@@ -45,45 +44,79 @@ public class AggregationFunctionFactory {
public static AggregationFunction getAggregationFunction(AggregationInfo aggregationInfo,
@Nullable BrokerRequest brokerRequest) {
String functionName = aggregationInfo.getAggregationType();
- List<String> expressions = AggregationFunctionUtils.getAggregationExpressions(aggregationInfo);
+ List<String> arguments = AggregationFunctionUtils.getArguments(aggregationInfo);
try {
String upperCaseFunctionName = functionName.toUpperCase();
+ String column = arguments.get(0);
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
- List<String> args = new ArrayList<>(expressions);
- if (remainingFunctionName.matches("\\d+")) {
- // Percentile
- args.add(remainingFunctionName);
- return new PercentileAggregationFunction(args);
- } else if (remainingFunctionName.matches("EST\\d+")) {
- // PercentileEst
- args.add(remainingFunctionName.substring(3));
- return new PercentileEstAggregationFunction(args);
- } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
- // PercentileTDigest
- args.add(remainingFunctionName.substring(7));
- return new PercentileTDigestAggregationFunction(args);
- } else if (remainingFunctionName.matches("\\d+MV")) {
- // PercentileMV
- args.add(remainingFunctionName.substring(0, remainingFunctionName.length() - 2));
- return new PercentileMVAggregationFunction(args);
- } else if (remainingFunctionName.matches("EST\\d+MV")) {
- // PercentileEstMV
- args.add(remainingFunctionName.substring(3, remainingFunctionName.length() - 2));
- return new PercentileEstMVAggregationFunction(args);
- } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
- // PercentileTDigestMV
- args.add(remainingFunctionName.substring(7, remainingFunctionName.length() - 2));
- return new PercentileTDigestMVAggregationFunction(args);
- } else {
- throw new IllegalArgumentException();
+ int numArguments = arguments.size();
+ if (numArguments == 1) {
+ // Single argument percentile (e.g. Percentile99(foo), PercentileTDigest95(bar), etc.)
+ if (remainingFunctionName.matches("\\d+")) {
+ // Percentile
+ return new PercentileAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(remainingFunctionName));
+ } else if (remainingFunctionName.matches("EST\\d+")) {
+ // PercentileEst
+ String percentileString = remainingFunctionName.substring(3);
+ return new PercentileEstAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(percentileString));
+ } else if (remainingFunctionName.matches("TDIGEST\\d+")) {
+ // PercentileTDigest
+ String percentileString = remainingFunctionName.substring(7);
+ return new PercentileTDigestAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(percentileString));
+ } else if (remainingFunctionName.matches("\\d+MV")) {
+ // PercentileMV
+ String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2);
+ return new PercentileMVAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(percentileString));
+ } else if (remainingFunctionName.matches("EST\\d+MV")) {
+ // PercentileEstMV
+ String percentileString = remainingFunctionName.substring(3, remainingFunctionName.length() - 2);
+ return new PercentileEstMVAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(percentileString));
+ } else if (remainingFunctionName.matches("TDIGEST\\d+MV")) {
+ // PercentileTDigestMV
+ String percentileString = remainingFunctionName.substring(7, remainingFunctionName.length() - 2);
+ return new PercentileTDigestMVAggregationFunction(column,
+ AggregationFunctionUtils.parsePercentile(percentileString));
+ }
+ } else if (numArguments == 2) {
+ // Double arguments percentile (e.g. percentile(foo, 99), percentileTDigest(bar, 95), etc.)
+ int percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
+ if (remainingFunctionName.isEmpty()) {
+ // Percentile
+ return new PercentileAggregationFunction(column, percentile);
+ }
+ if (remainingFunctionName.equals("EST")) {
+ // PercentileEst
+ return new PercentileEstAggregationFunction(column, percentile);
+ }
+ if (remainingFunctionName.equals("TDIGEST")) {
+ // PercentileTDigest
+ return new PercentileTDigestAggregationFunction(column, percentile);
+ }
+ if (remainingFunctionName.equals("MV")) {
+ // PercentileMV
+ return new PercentileMVAggregationFunction(column, percentile);
+ }
+ if (remainingFunctionName.equals("ESTMV")) {
+ // PercentileEstMV
+ return new PercentileEstMVAggregationFunction(column, percentile);
+ }
+ if (remainingFunctionName.equals("TDIGESTMV")) {
+ // PercentileTDigestMV
+ return new PercentileTDigestMVAggregationFunction(column, percentile);
+ }
}
+ throw new IllegalArgumentException("Invalid percentile function");
} else {
- String column = expressions.get(0);
switch (AggregationFunctionType.valueOf(upperCaseFunctionName)) {
case COUNT:
- return new CountAggregationFunction(column);
+ return new CountAggregationFunction();
case MIN:
return new MinAggregationFunction(column);
case MAX:
@@ -103,7 +136,7 @@ public class AggregationFunctionFactory {
case FASTHLL:
return new FastHLLAggregationFunction(column);
case DISTINCTCOUNTTHETASKETCH:
- return new DistinctCountThetaSketchAggregationFunction(expressions);
+ return new DistinctCountThetaSketchAggregationFunction(arguments);
case COUNTMV:
return new CountMVAggregationFunction(column);
case MINMV:
@@ -125,14 +158,13 @@ public class AggregationFunctionFactory {
case DISTINCT:
Preconditions.checkState(brokerRequest != null,
"Broker request must be provided for 'DISTINCT' aggregation function");
- return new DistinctAggregationFunction(expressions, brokerRequest.getOrderBy(),
- brokerRequest.getLimit());
+ return new DistinctAggregationFunction(arguments, brokerRequest.getOrderBy(), brokerRequest.getLimit());
default:
throw new IllegalArgumentException();
}
}
} catch (Exception e) {
- throw new BadQueryRequestException("Invalid aggregation function name: " + functionName);
+ throw new BadQueryRequestException("Invalid aggregation: " + aggregationInfo);
}
}
}
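The factory now accepts a percentile in two forms. It can be encoded in the function name, as in PERCENTILE99(foo) or PERCENTILETDIGEST95(bar), or passed as a second argument, as in percentile(foo, 99). parsePercentile also strips the single quotes that the second argument may arrive with. The following is a condensed sketch of that dispatch. PercentileParser and Spec are hypothetical, and the sketch returns a descriptor instead of constructing the real Percentile*AggregationFunction classes. The two regular expressions consolidate the six-way if/else chain above and are not the code in the commit.

```java
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical sketch of the percentile name/argument handling; not Pinot code.
final class PercentileParser {
  // Single-argument spellings: 99, EST99, TDIGEST99, 99MV, EST99MV, TDIGEST99MV.
  private static final Pattern NAME_WITH_PERCENTILE = Pattern.compile("(EST|TDIGEST)?(\\d+)(MV)?");
  // Two-argument spellings: "", EST, TDIGEST, MV, ESTMV, TDIGESTMV.
  private static final Pattern NAME_ONLY = Pattern.compile("(EST|TDIGEST)?(MV)?");

  // Which variant to build, on which column, at which percentile.
  static final class Spec {
    final String variant;
    final String column;
    final int percentile;

    Spec(String variant, String column, int percentile) {
      this.variant = variant;
      this.column = column;
      this.percentile = percentile;
    }
  }

  // Same rule as parsePercentile: strip optional single quotes, then require 0..100.
  static int parsePercentile(String s) {
    String digits = s.charAt(0) == '\'' ? s.substring(1, s.length() - 1) : s;
    int percentile = Integer.parseInt(digits);
    if (percentile < 0 || percentile > 100) {
      throw new IllegalStateException("Percentile out of range: " + percentile);
    }
    return percentile;
  }

  static Spec parse(String functionName, List<String> arguments) {
    String upper = functionName.toUpperCase(Locale.ROOT);
    if (!upper.startsWith("PERCENTILE")) {
      throw new IllegalArgumentException("Not a percentile function: " + functionName);
    }
    String rest = upper.substring("PERCENTILE".length());
    String column = arguments.get(0);
    if (arguments.size() == 1) {
      Matcher m = NAME_WITH_PERCENTILE.matcher(rest);
      if (m.matches()) {
        String variant = (m.group(1) == null ? "" : m.group(1)) + (m.group(3) == null ? "" : "MV");
        return new Spec(variant, column, parsePercentile(m.group(2)));
      }
    } else if (arguments.size() == 2 && NAME_ONLY.matcher(rest).matches()) {
      return new Spec(rest, column, parsePercentile(arguments.get(1)));
    }
    throw new IllegalArgumentException("Invalid percentile function: " + functionName + arguments);
  }
}
```

Unlike the factory, this sketch lets the raw exceptions propagate instead of wrapping them in BadQueryRequestException.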
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
index 995440f..8b75740 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils.java
@@ -22,16 +22,20 @@ import com.google.common.base.Preconditions;
import com.google.common.math.DoubleMath;
import java.io.Serializable;
import java.util.Arrays;
-import java.util.LinkedHashSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
import org.apache.pinot.parsers.CompilerConstants;
@@ -44,79 +48,26 @@ public class AggregationFunctionUtils {
}
/**
- * Extracts the aggregation column (could be column name or UDF expression) from the {@link AggregationInfo}.
- *
- *
- */
- /**
- * Returns the arguments for {@link AggregationFunction} as List of Strings.
- * For backward compatibility, it uses the new Thrift field `expressions` if found, or else,
- * falls back to the previous aggregationParams based approach.
- *
- * @param aggregationInfo Aggregation Info
- * @return List of aggregation function arguments
+ * Extracts the aggregation function arguments (could be column name, transform function or constant) from the
+ * {@link AggregationInfo} as an array of Strings.
+ * <p>NOTE: For backward-compatibility, uses the new Thrift field `expressions` if found, or falls back to the old
+ * aggregationParams based approach.
*/
- public static List<String> getAggregationExpressions(AggregationInfo aggregationInfo) {
+ public static List<String> getArguments(AggregationInfo aggregationInfo) {
List<String> expressions = aggregationInfo.getExpressions();
if (expressions != null) {
return expressions;
+ } else {
+ // NOTE: When the server is upgraded before the broker, the expressions won't be set. Falls back to the old
+ // aggregationParams based approach.
+ String column = aggregationInfo.getAggregationParams().get(CompilerConstants.COLUMN_KEY_IN_AGGREGATION_INFO);
+ return Arrays.asList(column.split(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR));
}
-
- String params = aggregationInfo.getAggregationParams().get(CompilerConstants.COLUMN_KEY_IN_AGGREGATION_INFO);
- return Arrays.asList(params.split(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR));
}
/**
- * Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationInfo}.
- * Asserts that the function only expects one argument.
+ * Creates an array of {@link AggregationFunction}s based on the given {@link BrokerRequest}.
*/
- public static AggregationFunctionColumnPair getFunctionColumnPair(AggregationInfo aggregationInfo) {
- List<String> aggregationExpressions = getAggregationExpressions(aggregationInfo);
- int numExpressions = aggregationExpressions.size();
- AggregationFunctionType functionType =
- AggregationFunctionType.getAggregationFunctionType(aggregationInfo.getAggregationType());
- Preconditions
- .checkState(numExpressions == 1, "Expected one argument for '" + functionType + "', got: " + numExpressions);
- return new AggregationFunctionColumnPair(functionType, aggregationExpressions.get(0));
- }
-
- public static boolean isDistinct(AggregationFunctionContext[] functionContexts) {
- return functionContexts.length == 1
- && functionContexts[0].getAggregationFunction().getType() == AggregationFunctionType.DISTINCT;
- }
-
- /**
- * Creates an {@link AggregationFunctionContext} from the {@link AggregationInfo}.
- * NOTE: This method does not work for {@code DISTINCT} aggregation function.
- * TODO: Remove this method and always pass in the broker request
- */
- public static AggregationFunctionContext getAggregationFunctionContext(AggregationInfo aggregationInfo) {
- return getAggregationFunctionContext(aggregationInfo, null);
- }
-
- /**
- * NOTE: Broker request cannot be {@code null} for {@code DISTINCT} aggregation function.
- * TODO: Always pass in non-null broker request
- */
- public static AggregationFunctionContext getAggregationFunctionContext(AggregationInfo aggregationInfo,
- @Nullable BrokerRequest brokerRequest) {
- List<String> aggregationExpressions = getAggregationExpressions(aggregationInfo);
- AggregationFunction aggregationFunction =
- AggregationFunctionFactory.getAggregationFunction(aggregationInfo, brokerRequest);
- return new AggregationFunctionContext(aggregationFunction, aggregationExpressions);
- }
-
- public static AggregationFunctionContext[] getAggregationFunctionContexts(BrokerRequest brokerRequest) {
- List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
- int numAggregationFunctions = aggregationInfos.size();
- AggregationFunctionContext[] aggregationFunctionContexts = new AggregationFunctionContext[numAggregationFunctions];
- for (int i = 0; i < numAggregationFunctions; i++) {
- AggregationInfo aggregationInfo = aggregationInfos.get(i);
- aggregationFunctionContexts[i] = getAggregationFunctionContext(aggregationInfo, brokerRequest);
- }
- return aggregationFunctionContexts;
- }
-
public static AggregationFunction[] getAggregationFunctions(BrokerRequest brokerRequest) {
List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
int numAggregationFunctions = aggregationInfos.size();
@@ -128,6 +79,29 @@ public class AggregationFunctionUtils {
return aggregationFunctions;
}
+ /**
+ * (For Star-Tree) Creates an {@link AggregationFunctionColumnPair} from the {@link AggregationFunction}. Returns
+ * {@code null} if the {@link AggregationFunction} cannot be represented as an {@link AggregationFunctionColumnPair}
+ * (e.g. has multiple arguments, argument is not column etc.).
+ */
+ @Nullable
+ public static AggregationFunctionColumnPair getAggregationFunctionColumnPair(
+ AggregationFunction aggregationFunction) {
+ AggregationFunctionType aggregationFunctionType = aggregationFunction.getType();
+ if (aggregationFunctionType == AggregationFunctionType.COUNT) {
+ return AggregationFunctionColumnPair.COUNT_STAR;
+ }
+ //noinspection unchecked
+ List<TransformExpressionTree> inputExpressions = aggregationFunction.getInputExpressions();
+ if (inputExpressions.size() == 1) {
+ TransformExpressionTree inputExpression = inputExpressions.get(0);
+ if (inputExpression.isColumn()) {
+ return new AggregationFunctionColumnPair(aggregationFunctionType, inputExpression.getValue());
+ }
+ }
+ return null;
+ }
+
public static boolean[] getAggregationFunctionsSelectStatus(List<AggregationInfo> aggregationInfos) {
int numAggregationFunctions = aggregationInfos.size();
boolean[] aggregationFunctionsStatus = new boolean[numAggregationFunctions];
@@ -164,13 +138,20 @@ public class AggregationFunctionUtils {
/**
* Utility function to parse percentile value from string.
- * Asserts that percentile value is within 0 and 100.
+ * <p>Asserts that percentile value is within 0 and 100.
+ * <p>NOTE: When percentileString is from the second argument (e.g. percentile(foo, 99), percentileTDigest(bar, 95),
+ * etc.), it might be standardized into single-quoted format.
*
* @param percentileString Input String
* @return Percentile value parsed from String.
*/
public static int parsePercentile(String percentileString) {
- int percentile = Integer.parseInt(percentileString);
+ int percentile;
+ if (percentileString.charAt(0) == '\'') {
+ percentile = Integer.parseInt(percentileString.substring(1, percentileString.length() - 1));
+ } else {
+ percentile = Integer.parseInt(percentileString);
+ }
Preconditions.checkState(percentile >= 0 && percentile <= 100);
return percentile;
}
@@ -181,37 +162,63 @@ public class AggregationFunctionUtils {
* @param arguments Arguments to concatenate
* @return Concatenated String of arguments
*/
- public static String concatArgs(List<String> arguments) {
- return (arguments.size() > 1) ? String.join(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR, arguments)
- : arguments.get(0);
+ public static String concatArgs(String[] arguments) {
+ return arguments.length > 1 ? String.join(CompilerConstants.AGGREGATION_FUNCTION_ARG_SEPARATOR, arguments)
+ : arguments[0];
}
/**
- * Compiles and returns all transform expressions required for computing the aggregation, group-by
- * and order-by
- *
- * @param brokerRequest Broker Request
- * @param functionContexts Aggregation Function contexts
- * @return Set of compiled expressions in the aggregation, group-by and order-by clauses
+ * Collects all transform expressions required for aggregation/group-by queries.
+ * <p>NOTE: We don't need to consider order-by columns here as the ordering is only allowed for aggregation functions
+ * or group-by expressions.
*/
- public static Set<TransformExpressionTree> collectExpressionsToTransform(BrokerRequest brokerRequest,
- AggregationFunctionContext[] functionContexts) {
-
- Set<TransformExpressionTree> expressionTrees = new LinkedHashSet<>();
- for (AggregationFunctionContext functionContext : functionContexts) {
- AggregationFunction function = functionContext.getAggregationFunction();
- expressionTrees.addAll(function.getInputExpressions());
+ public static Set<TransformExpressionTree> collectExpressionsToTransform(AggregationFunction[] aggregationFunctions,
+ @Nullable TransformExpressionTree[] groupByExpressions) {
+ Set<TransformExpressionTree> expressions = new HashSet<>();
+ for (AggregationFunction aggregationFunction : aggregationFunctions) {
+ //noinspection unchecked
+ expressions.addAll(aggregationFunction.getInputExpressions());
+ }
+ if (groupByExpressions != null) {
+ expressions.addAll(Arrays.asList(groupByExpressions));
}
+ return expressions;
+ }
- // Extract group-by expressions
- if (brokerRequest.isSetGroupBy()) {
- for (String expression : brokerRequest.getGroupBy().getExpressions()) {
- expressionTrees.add(TransformExpressionTree.compileToExpressionTree(expression));
- }
+ /**
+ * Creates a map from expression required by the {@link AggregationFunction} to {@link BlockValSet} fetched from the
+ * {@link TransformBlock}.
+ */
+ public static Map<TransformExpressionTree, BlockValSet> getBlockValSetMap(AggregationFunction aggregationFunction,
+ TransformBlock transformBlock) {
+ //noinspection unchecked
+ List<TransformExpressionTree> expressions = aggregationFunction.getInputExpressions();
+ int numExpressions = expressions.size();
+ if (numExpressions == 0) {
+ return Collections.emptyMap();
}
+ if (numExpressions == 1) {
+ TransformExpressionTree expression = expressions.get(0);
+ return Collections.singletonMap(expression, transformBlock.getBlockValueSet(expression));
+ }
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap = new HashMap<>();
+ for (TransformExpressionTree expression : expressions) {
+ blockValSetMap.put(expression, transformBlock.getBlockValueSet(expression));
+ }
+ return blockValSetMap;
+ }
- // TODO: Add order-by expressions when available in brokerRequest for aggregation queries.
- // The current order-by implementation assumes that ordering will be on aggregation/group-by columns.
- return expressionTrees;
+ /**
+ * (For Star-Tree) Creates a map from expression required by the {@link AggregationFunctionColumnPair} to
+ * {@link BlockValSet} fetched from the {@link TransformBlock}.
+ * <p>NOTE: We construct the map with original column name as the key but fetch BlockValSet with the aggregation
+ * function pair so that the aggregation result column name is consistent with or without star-tree.
+ */
+ public static Map<TransformExpressionTree, BlockValSet> getBlockValSetMap(
+ AggregationFunctionColumnPair aggregationFunctionColumnPair, TransformBlock transformBlock) {
+ TransformExpressionTree expression = new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER,
+ aggregationFunctionColumnPair.getColumn(), null);
+ BlockValSet blockValSet = transformBlock.getBlockValueSet(aggregationFunctionColumnPair.toColumnName());
+ return Collections.singletonMap(expression, blockValSet);
}
}
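For star-tree, getAggregationFunctionColumnPair maps COUNT to COUNT_STAR. It maps a function whose single argument is a plain column to a (function, column) pair, and returns null for anything else. The star-tree overload of getBlockValSetMap then keys the map by the original column while reading values from the pre-aggregated column, so the function sees the same key with or without star-tree. The sketch below is a simplified illustration. StarTreePairSketch and its nested types are hypothetical, and the "function__column" naming in toColumnName is an assumption made only for this example.

```java
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Hypothetical simplified model of the star-tree helpers above; not Pinot code.
final class StarTreePairSketch {
  enum FunctionType { COUNT, SUM, MIN, MAX }

  // An input expression: either a plain column or something else (function call, literal).
  static final class Input {
    final boolean isColumn;
    final String value;

    Input(boolean isColumn, String value) {
      this.isColumn = isColumn;
      this.value = value;
    }
  }

  static final class FunctionColumnPair {
    final FunctionType type;
    final String column;

    FunctionColumnPair(FunctionType type, String column) {
      this.type = type;
      this.column = column;
    }

    // Assumption for illustration: pre-aggregated columns are named "<function>__<column>".
    String toColumnName() {
      return type.name().toLowerCase(Locale.ROOT) + "__" + column;
    }
  }

  static final FunctionColumnPair COUNT_STAR = new FunctionColumnPair(FunctionType.COUNT, "*");

  // Returns null when the function cannot be served by a single pre-aggregated column.
  static FunctionColumnPair toPair(FunctionType type, List<Input> inputs) {
    if (type == FunctionType.COUNT) {
      return COUNT_STAR;
    }
    if (inputs.size() == 1 && inputs.get(0).isColumn) {
      return new FunctionColumnPair(type, inputs.get(0).value);
    }
    return null;
  }

  // Key by the original column, but read values from the pre-aggregated column.
  static Map<String, double[]> blockValSetMap(FunctionColumnPair pair, Map<String, double[]> block) {
    return Collections.singletonMap(pair.column, block.get(pair.toColumnName()));
  }
}
```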
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
index 33d21a6..52f2d6c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgAggregationFunction.java
@@ -18,8 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,23 +29,14 @@ import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
-import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class AvgAggregationFunction implements AggregationFunction<AvgPair, Double> {
+public class AvgAggregationFunction extends BaseSingleInputAggregationFunction<AvgPair, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
- protected final String _column;
- protected final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public AvgAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -56,21 +45,6 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -87,8 +61,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -96,7 +70,7 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
for (int i = 0; i < length; i++) {
sum += doubleValues[i];
}
- setAggregationResult(aggregationResultHolder, sum, (long) length);
+ setAggregationResult(aggregationResultHolder, sum, length);
} else {
// Serialized AvgPair
byte[][] bytesValues = blockValSet.getBytesValuesSV();
@@ -122,8 +96,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
@@ -142,8 +116,8 @@ public class AvgAggregationFunction implements AggregationFunction<AvgPair, Doub
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
index f45b41f..abe319b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AvgMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class AvgMVAggregationFunction extends AvgAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public AvgMVAggregationFunction(String column) {
super(column);
}
@@ -46,8 +43,9 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
double sum = 0.0;
long count = 0L;
for (int i = 0; i < length; i++) {
@@ -62,8 +60,8 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
aggregateOnGroupKey(groupKeyArray[i], groupByResultHolder, valuesArray[i]);
}
@@ -71,8 +69,8 @@ public class AvgMVAggregationFunction extends AvgAggregationFunction {
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
new file mode 100644
index 0000000..43641b2
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/BaseSingleInputAggregationFunction.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.Collections;
+import java.util.List;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
+
+
+/**
+ * Base implementation of {@link AggregationFunction} with single input expression.
+ */
+public abstract class BaseSingleInputAggregationFunction<I, F extends Comparable> implements AggregationFunction<I, F> {
+ protected final String _column;
+ protected final TransformExpressionTree _expression;
+
+ /**
+ * Constructor for the class.
+ *
+ * @param column Column to aggregate on (could be column name or transform function).
+ */
+ public BaseSingleInputAggregationFunction(String column) {
+ _column = column;
+ _expression = TransformExpressionTree.compileToExpressionTree(column);
+ }
+
+ @Override
+ public String getColumnName() {
+ return getType().getName() + "_" + _column;
+ }
+
+ @Override
+ public String getResultColumnName() {
+ return getType().getName().toLowerCase() + "(" + _column + ")";
+ }
+
+ @Override
+ public List<TransformExpressionTree> getInputExpressions() {
+ return Collections.singletonList(_expression);
+ }
+}
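The new base class compiles the input string once, in the constructor, and then uses the compiled expression (not its string form) as the key into `blockValSetMap`. The sketch below is a minimal, self-contained model of that pattern, not the Pinot implementation. `Expr`, `BaseSingleInput` and `AvgFunction` are hypothetical stand-ins for `TransformExpressionTree`, `BaseSingleInputAggregationFunction` and `AvgAggregationFunction`. The "compile" step is reduced to trimming whitespace, and the average is returned directly instead of being kept as an `AvgPair`.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for TransformExpressionTree: a value object whose
// equals()/hashCode() compare content, so it can be used as a HashMap key.
final class Expr {
  final String value;

  private Expr(String value) {
    this.value = value;
  }

  // Hypothetical "compile" step: the real parser builds a tree; here we only trim.
  static Expr compile(String expression) {
    return new Expr(expression.trim());
  }

  @Override
  public boolean equals(Object o) {
    return o instanceof Expr && ((Expr) o).value.equals(value);
  }

  @Override
  public int hashCode() {
    return value.hashCode();
  }
}

// Hypothetical stand-in for BaseSingleInputAggregationFunction.
abstract class BaseSingleInput {
  protected final String column;
  protected final Expr expression;

  BaseSingleInput(String column) {
    this.column = column;
    this.expression = Expr.compile(column); // compiled once, at construction time
  }

  List<Expr> getInputExpressions() {
    return Collections.singletonList(expression);
  }
}

// Hypothetical, simplified stand-in for AvgAggregationFunction (single-valued input only).
final class AvgFunction extends BaseSingleInput {
  AvgFunction(String column) {
    super(column);
  }

  // Looks up the input block by the compiled expression and averages the first
  // `length` values. Assumes length > 0.
  double aggregate(int length, Map<Expr, double[]> blockValSetMap) {
    double[] values = blockValSetMap.get(expression);
    double sum = 0.0;
    for (int i = 0; i < length; i++) {
      sum += values[i];
    }
    return sum / length;
  }
}
```

In this model, the caller fills the map using the keys returned by `getInputExpressions()`. Lookups therefore rely only on `equals()`/`hashCode()` of the expression, with no string conversion on the hot path.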
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
index 20ce56b..a437ce0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountAggregationFunction.java
@@ -29,21 +29,17 @@ import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.DoubleAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
public class CountAggregationFunction implements AggregationFunction<Long, Long> {
- private static final String COLUMN_NAME = AggregationFunctionType.COUNT.getName() + "_star";
+ private static final String COLUMN_NAME = "count_star";
+ private static final String RESULT_COLUMN_NAME = "count(*)";
private static final double DEFAULT_INITIAL_VALUE = 0.0;
-
- protected final String _column;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
- public CountAggregationFunction(String column) {
- _column = column;
- }
+ // Special expression used by star-tree to pass in BlockValSet
+ private static final TransformExpressionTree STAR_TREE_COUNT_STAR_EXPRESSION =
+ new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, AggregationFunctionColumnPair.STAR,
+ null);
@Override
public AggregationFunctionType getType() {
@@ -57,7 +53,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
@Override
public String getResultColumnName() {
- return AggregationFunctionType.COUNT.getName().toLowerCase() + "(*)";
+ return RESULT_COLUMN_NAME;
}
@Override
@@ -82,12 +78,12 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
if (blockValSetMap.size() == 0) {
aggregationResultHolder.setValue(aggregationResultHolder.getDoubleResult() + length);
} else {
// Star-tree pre-aggregated values
- long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+ long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
long count = 0;
for (int i = 0; i < length; i++) {
count += valueArray[i];
@@ -98,7 +94,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
if (blockValSetMap.size() == 0) {
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
@@ -106,7 +102,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
}
} else {
// Star-tree pre-aggregated values
- long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+ long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -116,7 +112,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
if (blockValSetMap.size() == 0) {
for (int i = 0; i < length; i++) {
for (int groupKey : groupKeysArray[i]) {
@@ -125,7 +121,7 @@ public class CountAggregationFunction implements AggregationFunction<Long, Long>
}
} else {
// Star-tree pre-aggregated values
- long[] valueArray = blockValSetMap.get(_column).getLongValuesSV();
+ long[] valueArray = blockValSetMap.get(STAR_TREE_COUNT_STAR_EXPRESSION).getLongValuesSV();
for (int i = 0; i < length; i++) {
long value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
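`CountAggregationFunction` no longer keeps a column. An empty `blockValSetMap` means every row counts once. A non-empty map means the star-tree supplied pre-aggregated counts under the special `*` expression. The following sketch models that branching with plain `String` keys for brevity. `CountSketch` and its `STAR` constant are hypothetical and only mirror `STAR_TREE_COUNT_STAR_EXPRESSION`.

```java
import java.util.Map;

final class CountSketch {
  // Hypothetical sentinel key, mirroring the "*" column the star-tree passes in.
  static final String STAR = "*";

  private CountSketch() {
  }

  // Without pre-aggregation the map is empty and every row counts once;
  // with star-tree, each row carries a pre-aggregated count under the STAR key.
  static long aggregate(int length, Map<String, long[]> blockValSetMap) {
    if (blockValSetMap.isEmpty()) {
      return length;
    }
    long[] preAggregatedCounts = blockValSetMap.get(STAR);
    long count = 0L;
    for (int i = 0; i < length; i++) {
      count += preAggregatedCounts[i];
    }
    return count;
  }
}
```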
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
index 29b540e..96afabe 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/CountMVAggregationFunction.java
@@ -29,16 +29,17 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class CountMVAggregationFunction extends CountAggregationFunction {
-
- private final List<TransformExpressionTree> _inputExpressions;
+ private final String _column;
+ private final TransformExpressionTree _expression;
/**
* Constructor for the class.
- * @param column Column name to aggregate on.
+ *
+ * @param column Column to aggregate on (could be column name or transform function).
*/
public CountMVAggregationFunction(String column) {
- super(column);
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(column));
+ _column = column;
+ _expression = TransformExpressionTree.compileToExpressionTree(column);
}
@Override
@@ -48,17 +49,17 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
@Override
public String getColumnName() {
- return getType().getName() + "_" + _column;
+ return AggregationFunctionType.COUNTMV.getName() + "_" + _column;
}
@Override
public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
+ return AggregationFunctionType.COUNTMV.getName().toLowerCase() + "(" + _column + ")";
}
@Override
public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
+ return Collections.singletonList(_expression);
}
@Override
@@ -68,8 +69,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
long count = 0L;
for (int i = 0; i < length; i++) {
count += valueArray[i];
@@ -79,8 +80,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -89,8 +90,8 @@ public class CountMVAggregationFunction extends CountAggregationFunction {
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- int[] valueArray = blockValSetMap.get(_column).getNumMVEntries();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ int[] valueArray = blockValSetMap.get(_expression).getNumMVEntries();
for (int i = 0; i < length; i++) {
int value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
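In `CountMVAggregationFunction.aggregateGroupByMV`, a row can belong to several groups. Each row therefore adds its number of multi-value entries to every group key it maps to. The sketch below models this with a plain `double[]` in place of `GroupByResultHolder`. `CountMvSketch` and its parameters are hypothetical.

```java
final class CountMvSketch {
  private CountMvSketch() {
  }

  // Each row contributes its number of multi-value entries to every group key it maps to
  // (group-by on a multi-value column can assign one row to several groups).
  static double[] aggregateGroupByMV(int length, int[][] groupKeysArray, int[] numMVEntries, int numGroups) {
    double[] results = new double[numGroups];
    for (int i = 0; i < length; i++) {
      int value = numMVEntries[i];
      for (int groupKey : groupKeysArray[i]) {
        results[groupKey] += value;
      }
    }
    return results;
  }
}
```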
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
index b09f028..8aceb29 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctAggregationFunction.java
@@ -37,7 +37,6 @@ import org.apache.pinot.core.query.aggregation.DistinctTable;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.util.GroupByUtils;
-import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
/**
@@ -79,12 +78,13 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
@Override
public String getColumnName() {
- return getType().getName() + "_" + AggregationFunctionUtils.concatArgs(Arrays.asList(_columns));
+ return AggregationFunctionType.DISTINCT.getName() + "_" + AggregationFunctionUtils.concatArgs(_columns);
}
@Override
public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + AggregationFunctionUtils.concatArgs(Arrays.asList(_columns)) + ")";
+ return AggregationFunctionType.DISTINCT.getName().toLowerCase() + "(" + AggregationFunctionUtils
+ .concatArgs(_columns) + ")";
}
@Override
@@ -104,23 +104,23 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- int numColumns = _columns.length;
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
int numBlockValSets = blockValSetMap.size();
- Preconditions.checkState(numBlockValSets == numColumns, "Size mismatch: numBlockValSets = %s, numColumns = %s",
- numBlockValSets, numColumns);
-
- DistinctTable distinctTable = aggregationResultHolder.getResult();
- BlockValSet[] blockValSets = new BlockValSet[numColumns];
-
- for (int i = 0; i < numColumns; i++) {
- blockValSets[i] = blockValSetMap.get(_columns[i]);
+ int numExpressions = _inputExpressions.size();
+ Preconditions
+ .checkState(numBlockValSets == numExpressions, "Size mismatch: numBlockValSets = %s, numExpressions = %s",
+ numBlockValSets, numExpressions);
+
+ BlockValSet[] blockValSets = new BlockValSet[numExpressions];
+ for (int i = 0; i < numExpressions; i++) {
+ blockValSets[i] = blockValSetMap.get(_inputExpressions.get(i));
}
+ DistinctTable distinctTable = aggregationResultHolder.getResult();
if (distinctTable == null) {
- ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
- for (int i = 0; i < numColumns; i++) {
- columnDataTypes[i] = ColumnDataType.fromDataTypeSV(blockValSetMap.get(_columns[i]).getValueType());
+ ColumnDataType[] columnDataTypes = new ColumnDataType[numExpressions];
+ for (int i = 0; i < numExpressions; i++) {
+ columnDataTypes[i] = ColumnDataType.fromDataTypeSV(blockValSetMap.get(_inputExpressions.get(i)).getValueType());
}
DataSchema dataSchema = new DataSchema(_columns, columnDataTypes);
distinctTable = new DistinctTable(dataSchema, _orderBy, _capacity);
@@ -147,8 +147,7 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
if (distinctTable != null) {
return distinctTable;
} else {
- int numColumns = _columns.length;
- ColumnDataType[] columnDataTypes = new ColumnDataType[numColumns];
+ ColumnDataType[] columnDataTypes = new ColumnDataType[_columns.length];
// NOTE: Use STRING for unknown type
Arrays.fill(columnDataTypes, ColumnDataType.STRING);
return new DistinctTable(new DataSchema(_columns, columnDataTypes), _orderBy, _capacity);
@@ -188,13 +187,13 @@ public class DistinctAggregationFunction implements AggregationFunction<Distinct
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
throw new UnsupportedOperationException("Operation not supported for DISTINCT aggregation function");
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
throw new UnsupportedOperationException("Operation not supported for DISTINCT aggregation function");
}
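`DistinctAggregationFunction` takes several inputs. It checks that the map holds exactly one block per input expression, then resolves the blocks in the declared order of `_inputExpressions`, so the result schema lines up with `_columns`. The following is a minimal sketch of that step with `String` keys and `Object` values. `DistinctInputsSketch` is hypothetical, and it throws `IllegalStateException` directly instead of using Guava's `Preconditions.checkState`, which throws the same exception type.

```java
import java.util.List;
import java.util.Map;

final class DistinctInputsSketch {
  private DistinctInputsSketch() {
  }

  // Resolves one block per input expression, in declaration order, after checking that the
  // map holds exactly one entry per expression.
  static Object[] resolve(List<String> inputExpressions, Map<String, Object> blockValSetMap) {
    int numExpressions = inputExpressions.size();
    int numBlockValSets = blockValSetMap.size();
    if (numBlockValSets != numExpressions) {
      throw new IllegalStateException(String.format(
          "Size mismatch: numBlockValSets = %s, numExpressions = %s", numBlockValSets, numExpressions));
    }
    Object[] blockValSets = new Object[numExpressions];
    for (int i = 0; i < numExpressions; i++) {
      blockValSets[i] = blockValSetMap.get(inputExpressions.get(i));
    }
    return blockValSets;
  }
}
```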
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
index ebc3aaf..4fafb1d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
@@ -19,8 +19,6 @@
package org.apache.pinot.core.query.aggregation.function;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -33,18 +31,10 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import org.apache.pinot.spi.data.FieldSpec;
-public class DistinctCountAggregationFunction implements AggregationFunction<IntOpenHashSet, Integer> {
+public class DistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<IntOpenHashSet, Integer> {
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -53,21 +43,6 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -83,8 +58,9 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
IntOpenHashSet valueSet = getValueSet(aggregationResultHolder);
FieldSpec.DataType valueType = blockValSet.getValueType();
@@ -126,8 +102,8 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
FieldSpec.DataType valueType = blockValSet.getValueType();
switch (valueType) {
@@ -168,8 +144,8 @@ public class DistinctCountAggregationFunction implements AggregationFunction<Int
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
FieldSpec.DataType valueType = blockValSet.getValueType();
switch (valueType) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
index f8371ea..88beff9 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLAggregationFunction.java
@@ -20,8 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -35,19 +33,11 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class DistinctCountHLLAggregationFunction implements AggregationFunction<HyperLogLog, Long> {
- protected final String _column;
-
+public class DistinctCountHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
public static final int DEFAULT_LOG2M = 8;
- private final List<TransformExpressionTree> _inputExpressions;
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountHLLAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -56,21 +46,6 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -86,8 +61,9 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
if (valueType != DataType.BYTES) {
@@ -151,8 +127,8 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
switch (valueType) {
@@ -211,8 +187,8 @@ public class DistinctCountHLLAggregationFunction implements AggregationFunction<
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
switch (valueType) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
index eee8db4..1ff547c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountHLLMVAggregationFunction.java
@@ -21,6 +21,7 @@ package org.apache.pinot.core.query.aggregation.function;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,10 +30,6 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountHLLMVAggregationFunction(String column) {
super(column);
}
@@ -48,10 +45,11 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
HyperLogLog hyperLogLog = getDefaultHyperLogLog(aggregationResultHolder);
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
switch (valueType) {
case INT:
@@ -102,8 +100,8 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
switch (valueType) {
@@ -160,8 +158,8 @@ public class DistinctCountHLLMVAggregationFunction extends DistinctCountHLLAggre
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
DataType valueType = blockValSet.getValueType();
switch (valueType) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
index 5b69532..910c8aa 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
@@ -21,6 +21,7 @@ package org.apache.pinot.core.query.aggregation.function;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,10 +30,6 @@ import org.apache.pinot.spi.data.FieldSpec;
public class DistinctCountMVAggregationFunction extends DistinctCountAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountMVAggregationFunction(String column) {
super(column);
}
@@ -48,10 +45,11 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
IntOpenHashSet valueSet = getValueSet(aggregationResultHolder);
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
FieldSpec.DataType valueType = blockValSet.getValueType();
switch (valueType) {
case INT:
@@ -100,8 +98,8 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
FieldSpec.DataType valueType = blockValSet.getValueType();
switch (valueType) {
@@ -157,8 +155,8 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
FieldSpec.DataType valueType = blockValSet.getValueType();
switch (valueType) {
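The hunks above switch `blockValSetMap` from `Map<String, BlockValSet>` to `Map<TransformExpressionTree, BlockValSet>`. Each function now looks up its column with `_expression` instead of `_column`. A lookup like this only succeeds when the key class compares by value (`equals()`/`hashCode()`), not by identity. Below is a minimal sketch of that requirement. It makes two assumptions:

- A hypothetical `ExprKey` class stands in for `TransformExpressionTree`, whose fields are not shown in this diff.
- `double[]` stands in for `BlockValSet`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-in for TransformExpressionTree: an immutable key whose
// equals()/hashCode() compare expression type and value, not object identity.
final class ExprKey {
  enum Type { IDENTIFIER, FUNCTION, LITERAL }

  private final Type _type;
  private final String _value;

  ExprKey(Type type, String value) {
    _type = type;
    _value = value;
  }

  // Convenience factory for a plain column reference.
  static ExprKey identifier(String column) {
    return new ExprKey(Type.IDENTIFIER, column);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey that = (ExprKey) o;
    return _type == that._type && Objects.equals(_value, that._value);
  }

  @Override
  public int hashCode() {
    return Objects.hash(_type, _value);
  }
}

final class ExprKeyedLookup {
  // Looks up a column with a freshly built key; this works only because an
  // equal-but-distinct ExprKey hashes and compares the same as the stored one.
  static double[] lookup(Map<ExprKey, double[]> blockValSetMap, String column) {
    return blockValSetMap.get(ExprKey.identifier(column));
  }
}
```

Because the key compares by value, the planner and the aggregation function can each build their own key object and still hit the same map entry. Neither has to share one instance.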
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
index fe946ca..40050f6 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLAggregationFunction.java
@@ -19,8 +19,6 @@
package org.apache.pinot.core.query.aggregation.function;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,25 +29,17 @@ import org.apache.pinot.core.query.aggregation.function.customobject.SerializedH
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-public class DistinctCountRawHLLAggregationFunction implements AggregationFunction<HyperLogLog, SerializedHLL> {
+public class DistinctCountRawHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, SerializedHLL> {
private final DistinctCountHLLAggregationFunction _distinctCountHLLAggregationFunction;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountRawHLLAggregationFunction(String column) {
- this(new DistinctCountHLLAggregationFunction(column), column);
+ this(column, new DistinctCountHLLAggregationFunction(column));
}
- DistinctCountRawHLLAggregationFunction(DistinctCountHLLAggregationFunction distinctCountHLLAggregationFunction,
- String column) {
+ DistinctCountRawHLLAggregationFunction(String column,
+ DistinctCountHLLAggregationFunction distinctCountHLLAggregationFunction) {
+ super(column);
_distinctCountHLLAggregationFunction = distinctCountHLLAggregationFunction;
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
}
@Override
@@ -58,21 +48,6 @@ public class DistinctCountRawHLLAggregationFunction implements AggregationFuncti
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
_distinctCountHLLAggregationFunction.accept(visitor);
}
@@ -89,19 +64,19 @@ public class DistinctCountRawHLLAggregationFunction implements AggregationFuncti
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
_distinctCountHLLAggregationFunction.aggregate(length, aggregationResultHolder, blockValSetMap);
}
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
_distinctCountHLLAggregationFunction.aggregateGroupBySV(length, groupKeyArray, groupByResultHolder, blockValSetMap);
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
_distinctCountHLLAggregationFunction
.aggregateGroupByMV(length, groupKeysArray, groupByResultHolder, blockValSetMap);
}
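`DistinctCountRawHLLAggregationFunction` (and, further down, `FastHLL`, `Max`, `Min` and `MinMaxRange`) now extends `BaseSingleInputAggregationFunction` and calls `super(column)`. It no longer keeps its own:

- `_column` and `_inputExpressions` fields;
- `getColumnName()`, `getResultColumnName()` and `getInputExpressions()` methods.

The body of the base class is not part of this diff. The sketch below only reproduces the naming formats visible in the removed methods. It assumes hypothetical names (`SingleInputAggregationSketch`, `MinMaxRangeSketch`, the type name `"minMaxRange"`) and uses a `String` in place of the compiled `TransformExpressionTree`.

```java
import java.util.Collections;
import java.util.List;
import java.util.Locale;

// Hypothetical sketch only: not Pinot's BaseSingleInputAggregationFunction.
// A String stands in for the compiled TransformExpressionTree.
abstract class SingleInputAggregationSketch {
  protected final String _column;
  private final List<String> _inputExpressions;

  protected SingleInputAggregationSketch(String column) {
    _column = column;
    // One function, one input expression, built once in the constructor.
    _inputExpressions = Collections.singletonList(column);
  }

  // Stand-in for getType().getName().
  protected abstract String getTypeName();

  // Same format as the removed per-class getColumnName(): <type>_<column>.
  public String getColumnName() {
    return getTypeName() + "_" + _column;
  }

  // Same format as the removed getResultColumnName(): <lowercased type>(<column>).
  // Locale.ROOT avoids locale-dependent lowercasing; the removed code called toLowerCase().
  public String getResultColumnName() {
    return getTypeName().toLowerCase(Locale.ROOT) + "(" + _column + ")";
  }

  public List<String> getInputExpressions() {
    return _inputExpressions;
  }
}

// A subclass now only forwards the column to super(...) and names its type.
final class MinMaxRangeSketch extends SingleInputAggregationSketch {
  MinMaxRangeSketch(String column) {
    super(column);
  }

  @Override
  protected String getTypeName() {
    return "minMaxRange"; // hypothetical type name for illustration
  }
}
```

With this shape, `new MinMaxRangeSketch("price")` yields the column name `minMaxRange_price` and the result column name `minmaxrange(price)`. Each subclass then contains only its aggregation logic.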
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
index 2a1cf2a..3712599 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountRawHLLMVAggregationFunction.java
@@ -23,12 +23,8 @@ import org.apache.pinot.common.function.AggregationFunctionType;
public class DistinctCountRawHLLMVAggregationFunction extends DistinctCountRawHLLAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public DistinctCountRawHLLMVAggregationFunction(String column) {
- super(new DistinctCountHLLMVAggregationFunction(column), column);
+ super(column, new DistinctCountHLLMVAggregationFunction(column));
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
index 2dd03d3..b7463a0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.java
@@ -64,6 +64,7 @@ import org.apache.pinot.sql.parsers.CalciteSqlParser;
public class DistinctCountThetaSketchAggregationFunction implements AggregationFunction<Map<String, Sketch>, Integer> {
private String _thetaSketchColumn;
+ private TransformExpressionTree _thetaSketchIdentifier;
private Set<String> _predicateStrings;
private Expression _postAggregationExpression;
private Set<PredicateInfo> _predicateInfoSet;
@@ -130,15 +131,15 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
Map<String, Union> result = getDefaultResult(aggregationResultHolder, _predicateStrings);
- Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+ Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
for (PredicateInfo predicateInfo : _predicateInfoSet) {
String predicate = predicateInfo.getStringVal();
- BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+ BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
FieldSpec.DataType valueType = blockValSet.getValueType();
PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
@@ -181,13 +182,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
for (PredicateInfo predicateInfo : _predicateInfoSet) {
String predicate = predicateInfo.getStringVal();
- BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+ BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
FieldSpec.DataType valueType = blockValSet.getValueType();
PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
@@ -237,13 +238,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchColumn).getBytesValuesSV(), length);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
for (PredicateInfo predicateInfo : _predicateInfoSet) {
String predicate = predicateInfo.getStringVal();
- BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getColumn());
+ BlockValSet blockValSet = blockValSetMap.get(predicateInfo.getExpression());
FieldSpec.DataType valueType = blockValSet.getValueType();
PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
@@ -425,10 +426,12 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
// Initialize the Theta-Sketch Column.
_thetaSketchColumn = arguments.get(0);
+ _thetaSketchIdentifier =
+ new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, _thetaSketchColumn, null);
// Initialize input expressions. It is expected they are covered between the theta-sketch column and the predicates.
_inputExpressions = new ArrayList<>();
- _inputExpressions.add(TransformExpressionTree.compileToExpressionTree(_thetaSketchColumn));
+ _inputExpressions.add(_thetaSketchIdentifier);
// Initialize thetaSketchParams
String paramsString = arguments.get(1);
@@ -443,14 +446,18 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
if (predicatesSpecified) {
for (String predicateString : _predicateStrings) {
+ // FIXME: Standardize predicate string?
+
Expression expression = CalciteSqlParser.compileToExpression(predicateString);
// TODO: Add support for complex predicates with AND/OR.
String filterColumn = ParserUtils.getFilterColumn(expression);
Predicate predicate = Predicate
.newPredicate(ParserUtils.getFilterType(expression), filterColumn, ParserUtils.getFilterValues(expression));
+ TransformExpressionTree filterExpression =
+ new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, filterColumn, null);
- _predicateInfoSet.add(new PredicateInfo(predicateString, filterColumn, predicate));
+ _predicateInfoSet.add(new PredicateInfo(predicateString, filterExpression, predicate));
_expressionMap.put(expression, predicateString);
_inputExpressions.add(new TransformExpressionTree(new IdentifierAstNode(filterColumn)));
}
@@ -461,12 +468,14 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
String filterColumn = ParserUtils.getFilterColumn(predicateExpression);
Predicate predicate = Predicate.newPredicate(ParserUtils.getFilterType(predicateExpression), filterColumn,
ParserUtils.getFilterValues(predicateExpression));
+ TransformExpressionTree filterExpression =
+ new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, filterColumn, null);
String predicateString = ParserUtils.standardizeExpression(predicateExpression, false);
_predicateStrings.add(predicateString);
- _predicateInfoSet.add(new PredicateInfo(predicateString, filterColumn, predicate));
+ _predicateInfoSet.add(new PredicateInfo(predicateString, filterExpression, predicate));
_expressionMap.put(predicateExpression, predicateString);
- _inputExpressions.add(new TransformExpressionTree(new IdentifierAstNode(filterColumn)));
+ _inputExpressions.add(filterExpression);
}
}
}
@@ -567,15 +576,15 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
*
*/
private static class PredicateInfo {
-
private final String _stringVal;
- private final String _column; // LHS
+ private final TransformExpressionTree _expression; // LHS
+ // FIXME: Predicate does not have equals() and hashCode() implemented
private final Predicate _predicate;
private PredicateEvaluator _predicateEvaluator;
- private PredicateInfo(String stringVal, String column, Predicate predicate) {
+ private PredicateInfo(String stringVal, TransformExpressionTree expression, Predicate predicate) {
_stringVal = stringVal;
- _column = column;
+ _expression = expression;
_predicate = predicate;
_predicateEvaluator = null; // Initialized lazily
}
@@ -584,8 +593,8 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
return _stringVal;
}
- public String getColumn() {
- return _column;
+ public TransformExpressionTree getExpression() {
+ return _expression;
}
public Predicate getPredicate() {
@@ -625,13 +634,13 @@ public class DistinctCountThetaSketchAggregationFunction implements AggregationF
}
PredicateInfo that = (PredicateInfo) o;
- return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_column, that._column) && Objects
+ return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_expression, that._expression) && Objects
.equals(_predicate, that._predicate);
}
@Override
public int hashCode() {
- return Objects.hash(_stringVal, _column, _predicate);
+ return Objects.hash(_stringVal, _expression, _predicate);
}
}
-}
\ No newline at end of file
+}
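The new FIXME on `PredicateInfo` says that `Predicate` does not implement `equals()` and `hashCode()`. `PredicateInfo.equals()` ends with `Objects.equals(_predicate, that._predicate)`, so that term falls back to reference identity. Two `PredicateInfo` objects parsed separately from the same predicate text would therefore be treated as distinct by a hash-based set such as `_predicateInfoSet`. The sketch below demonstrates the effect. `PredicateStub`, `PredicateInfoStub` and `PredicateInfoSetDemo` are hypothetical stand-ins, not Pinot classes, and a `String` replaces the `TransformExpressionTree` LHS.

```java
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

// Hypothetical stand-in for Predicate. Like the class named in the FIXME, it does
// not override equals()/hashCode(), so two instances are equal only if identical.
final class PredicateStub {
  private final String _column;
  private final String _value;

  PredicateStub(String column, String value) {
    _column = column;
    _value = value;
  }
}

// Same equals()/hashCode() shape as PredicateInfo in the hunk above.
final class PredicateInfoStub {
  private final String _stringVal;
  private final String _expression; // stand-in for the TransformExpressionTree LHS
  private final PredicateStub _predicate;

  PredicateInfoStub(String stringVal, String expression, PredicateStub predicate) {
    _stringVal = stringVal;
    _expression = expression;
    _predicate = predicate;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    PredicateInfoStub that = (PredicateInfoStub) o;
    return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_expression, that._expression)
        && Objects.equals(_predicate, that._predicate);
  }

  @Override
  public int hashCode() {
    return Objects.hash(_stringVal, _expression, _predicate);
  }
}

final class PredicateInfoSetDemo {
  // Two infos built from the same text but separate predicate objects:
  // the identity-based predicate field makes them unequal, so the set keeps both.
  static int sizeWithSeparatePredicates() {
    Set<PredicateInfoStub> set = new HashSet<>();
    set.add(new PredicateInfoStub("dim = 'a'", "dim", new PredicateStub("dim", "a")));
    set.add(new PredicateInfoStub("dim = 'a'", "dim", new PredicateStub("dim", "a")));
    return set.size();
  }

  // Sharing one predicate instance makes the infos equal, so the set keeps one.
  static int sizeWithSharedPredicate() {
    PredicateStub shared = new PredicateStub("dim", "a");
    Set<PredicateInfoStub> set = new HashSet<>();
    set.add(new PredicateInfoStub("dim = 'a'", "dim", shared));
    set.add(new PredicateInfoStub("dim = 'a'", "dim", shared));
    return set.size();
  }
}
```

Replacing `_column` with a value-comparable `_expression` does not help here. The equality of the whole `PredicateInfo` is still only as strong as its weakest field.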
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
index 12d92cc..0f7022e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/FastHLLAggregationFunction.java
@@ -20,8 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
import com.clearspring.analytics.stream.cardinality.HyperLogLog;
import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -38,20 +36,12 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
* Use {@link DistinctCountHLLAggregationFunction} on byte[] for serialized HyperLogLog.
*/
@Deprecated
-public class FastHLLAggregationFunction implements AggregationFunction<HyperLogLog, Long> {
+public class FastHLLAggregationFunction extends BaseSingleInputAggregationFunction<HyperLogLog, Long> {
public static final int DEFAULT_LOG2M = 8;
private static final int BYTE_TO_CHAR_OFFSET = 129;
- private final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public FastHLLAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -60,21 +50,6 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -90,8 +65,9 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- String[] values = blockValSetMap.get(_column).getStringValuesSV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ String[] values = blockValSetMap.get(_expression).getStringValuesSV();
try {
HyperLogLog hyperLogLog = aggregationResultHolder.getResult();
if (hyperLogLog != null) {
@@ -112,8 +88,8 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- String[] values = blockValSetMap.get(_column).getStringValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ String[] values = blockValSetMap.get(_expression).getStringValuesSV();
try {
for (int i = 0; i < length; i++) {
HyperLogLog value = convertStringToHLL(values[i]);
@@ -132,8 +108,8 @@ public class FastHLLAggregationFunction implements AggregationFunction<HyperLogL
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- String[] values = blockValSetMap.get(_column).getStringValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ String[] values = blockValSetMap.get(_expression).getStringValuesSV();
try {
for (int i = 0; i < length; i++) {
HyperLogLog value = convertStringToHLL(values[i]);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
index a9b3ff3..5d141b3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxAggregationFunction.java
@@ -18,8 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,19 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-public class MaxAggregationFunction implements AggregationFunction<Double, Double> {
+public class MaxAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_INITIAL_VALUE = Double.NEGATIVE_INFINITY;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MaxAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -52,21 +42,6 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -82,8 +57,9 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
double max = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
@@ -96,8 +72,8 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
int groupKey = groupKeyArray[i];
@@ -109,8 +85,8 @@ public class MaxAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
index 3083f87..434bcea 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MaxMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class MaxMVAggregationFunction extends MaxAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MaxMVAggregationFunction(String column) {
super(column);
}
@@ -46,8 +43,9 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
double max = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
for (double value : valuesArray[i]) {
@@ -61,8 +59,8 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
double max = groupByResultHolder.getDoubleResult(groupKey);
@@ -77,8 +75,8 @@ public class MaxMVAggregationFunction extends MaxAggregationFunction {
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
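The `aggregateGroupByMV` hunk in `MaxMVAggregationFunction` shows only the top of the loop. Each row `i` carries several values (`valuesArray[i]`) and belongs to several groups (`groupKeysArray[i]`). The sketch below fills in that double loop under two assumptions:

- The rest of the loop folds the max of the row's values into every listed group.
- A plain `double[]` stands in for `GroupByResultHolder`, pre-filled with `Double.NEGATIVE_INFINITY` (the `DEFAULT_INITIAL_VALUE` in `MaxAggregationFunction`).

`MaxMvGroupBySketch` is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical sketch of the group-by-MV max pattern; a double[] replaces
// Pinot's GroupByResultHolder and double[][] replaces the BlockValSet values.
final class MaxMvGroupBySketch {
  static double[] aggregateGroupByMV(int length, int[][] groupKeysArray, double[][] valuesArray,
      int numGroups) {
    double[] results = new double[numGroups];
    // Same starting value as MaxAggregationFunction.DEFAULT_INITIAL_VALUE.
    Arrays.fill(results, Double.NEGATIVE_INFINITY);
    for (int i = 0; i < length; i++) {
      double[] values = valuesArray[i];
      // Every group the row belongs to sees all of the row's values.
      for (int groupKey : groupKeysArray[i]) {
        double max = results[groupKey];
        for (double value : values) {
          if (value > max) {
            max = value;
          }
        }
        results[groupKey] = max;
      }
    }
    return results;
  }
}
```

As an example, take rows `{3.0, 7.0}` in groups `{0, 1}` and `{9.0, 1.0}` in group `{1}`. Group 0 ends at 7.0 and group 1 at 9.0. A group that no row touches stays at negative infinity.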
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
index ddf451a..68aca61 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinAggregationFunction.java
@@ -18,8 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,19 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-public class MinAggregationFunction implements AggregationFunction<Double, Double> {
+public class MinAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_VALUE = Double.POSITIVE_INFINITY;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MinAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -52,21 +42,6 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -82,8 +57,9 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
double min = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
@@ -96,8 +72,8 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
int groupKey = groupKeyArray[i];
@@ -109,8 +85,8 @@ public class MinAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
index f7b3141..e97a405 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class MinMVAggregationFunction extends MinAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MinMVAggregationFunction(String column) {
super(column);
}
@@ -46,8 +43,9 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
double min = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
for (double value : valuesArray[i]) {
@@ -61,8 +59,8 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
double min = groupByResultHolder.getDoubleResult(groupKey);
@@ -77,8 +75,8 @@ public class MinMVAggregationFunction extends MinAggregationFunction {
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
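The lookups above change from `blockValSetMap.get(_column)` to `blockValSetMap.get(_expression)`, with the map now keyed by `TransformExpressionTree`. That only works if the key type has value-based `equals()`/`hashCode()`, so that an expression compiled at planning time matches the one used to build the map. The block below is a minimal, self-contained sketch of that idea. `ExprKey` and `SingleInputLookup` are hypothetical stand-ins, not Pinot's `TransformExpressionTree`, whose actual fields and hashing may differ.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-in for an expression tree used as a map key.
// NOT Pinot's TransformExpressionTree; it only illustrates value-based equality.
final class ExprKey {
  private final String _value;           // column name or function name
  private final List<ExprKey> _children; // empty for a plain column
  private final int _hash;               // computed once, reused by every map lookup

  ExprKey(String value, ExprKey... children) {
    _value = value;
    _children = Collections.unmodifiableList(Arrays.asList(children));
    _hash = Objects.hash(_value, _children);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey that = (ExprKey) o;
    // Cheap hash comparison first, then structural comparison of the tree.
    return _hash == that._hash && _value.equals(that._value) && _children.equals(that._children);
  }

  @Override
  public int hashCode() {
    return _hash;
  }

  @Override
  public String toString() {
    if (_children.isEmpty()) {
      return _value;
    }
    StringBuilder sb = new StringBuilder(_value).append('(');
    for (int i = 0; i < _children.size(); i++) {
      if (i > 0) {
        sb.append(',');
      }
      sb.append(_children.get(i));
    }
    return sb.append(')').toString();
  }
}

// Mirrors the pattern in the diff: the function holds its input expression
// and uses it directly as the key into the block value map.
final class SingleInputLookup {
  private final ExprKey _expression;

  SingleInputLookup(ExprKey expression) {
    _expression = expression;
  }

  double[] values(Map<ExprKey, double[]> blockValSetMap) {
    return blockValSetMap.get(_expression);
  }
}
```

With this design, no string form of the expression has to be built per lookup. Two independently constructed trees for `add(a,b)` hit the same map entry, while `add(b,a)` does not.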
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
index ca4ec10..53aa53a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeAggregationFunction.java
@@ -18,8 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -34,18 +32,10 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMaxRangePair, Double> {
+public class MinMaxRangeAggregationFunction extends BaseSingleInputAggregationFunction<MinMaxRangePair, Double> {
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
-
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MinMaxRangeAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -54,21 +44,6 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -85,11 +60,11 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
@@ -130,8 +105,8 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
@@ -150,8 +125,8 @@ public class MinMaxRangeAggregationFunction implements AggregationFunction<MinMa
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
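`MinMaxRangeAggregationFunction` (and `SumAggregationFunction` further down) drops its own `_column`, `_inputExpressions`, `getColumnName()`, `getResultColumnName()` and `getInputExpressions()` and calls `super(column)` instead. This suggests these members now live in `BaseSingleInputAggregationFunction`, while the percentile variants keep overriding the naming methods. That base class is not shown in this part of the diff. The sketch below is a hypothetical, simplified reconstruction of the pattern: `SingleInputAggSketch`, `typeName()` and the sketch subclasses are invented for illustration, and plain strings stand in for compiled expression trees.

```java
import java.util.Collections;
import java.util.List;
import java.util.Locale;

// Hypothetical simplified base class; member names mirror the diff, bodies are assumptions.
abstract class SingleInputAggSketch {
  protected final String _column;
  // Real code would hold compiled expression trees rather than strings.
  protected final List<String> _inputExpressions;

  protected SingleInputAggSketch(String column) {
    _column = column;
    _inputExpressions = Collections.singletonList(column);
  }

  // Each concrete function only has to supply its type name.
  abstract String typeName();

  String getColumnName() {
    return typeName() + "_" + _column;
  }

  String getResultColumnName() {
    return typeName().toLowerCase(Locale.ROOT) + "(" + _column + ")";
  }

  List<String> getInputExpressions() {
    return _inputExpressions;
  }
}

// Simple functions inherit all naming logic.
final class SumSketch extends SingleInputAggSketch {
  SumSketch(String column) {
    super(column);
  }

  @Override
  String typeName() {
    return "SUM";
  }
}

// Percentile functions embed the percentile in the names, as the diff's overrides do.
final class PercentileSketch extends SingleInputAggSketch {
  private final int _percentile;

  PercentileSketch(String column, int percentile) {
    super(column);
    _percentile = percentile;
  }

  @Override
  String typeName() {
    return "PERCENTILE";
  }

  @Override
  String getColumnName() {
    return typeName() + _percentile + "_" + _column;
  }

  @Override
  String getResultColumnName() {
    return typeName().toLowerCase(Locale.ROOT) + _percentile + "(" + _column + ")";
  }
}
```

The percentile override in the diff also hard-codes `AggregationFunctionType.PERCENTILE` rather than calling `getType()`. A plausible reason is that MV subclasses then report the same column names as their single-value parent, though the diff itself does not state this.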
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
index 5a18190..d88a098 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/MinMaxRangeMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public MinMaxRangeMVAggregationFunction(String column) {
super(column);
}
@@ -46,8 +43,9 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
double min = Double.POSITIVE_INFINITY;
double max = Double.NEGATIVE_INFINITY;
for (int i = 0; i < length; i++) {
@@ -66,8 +64,8 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
aggregateOnGroupKey(groupKeyArray[i], groupByResultHolder, valuesArray[i]);
}
@@ -75,8 +73,8 @@ public class MinMaxRangeMVAggregationFunction extends MinMaxRangeAggregationFunc
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
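The `aggregateGroupByMV` methods above handle rows where the aggregated column is multi-valued (`getDoubleValuesMV()`) and each row may map to several group keys (`int[][] groupKeysArray`). The following self-contained sketch shows that nested loop for a min/max range. It is only a simplified illustration: plain arrays replace `GroupByResultHolder` and `MinMaxRangePair`, and the class name is hypothetical.

```java
import java.util.Arrays;

// Simplified sketch of the MV-column / MV-group-key loop; plain arrays stand in for
// GroupByResultHolder and MinMaxRangePair. Names are hypothetical.
final class MinMaxRangeGroupBySketch {
  private final double[] _min;
  private final double[] _max;

  MinMaxRangeGroupBySketch(int numGroups) {
    _min = new double[numGroups];
    _max = new double[numGroups];
    // Same neutral starting values as the diff's aggregate() method.
    Arrays.fill(_min, Double.POSITIVE_INFINITY);
    Arrays.fill(_max, Double.NEGATIVE_INFINITY);
  }

  // Row i contributes every value in valuesArray[i] to every group in groupKeysArray[i].
  void aggregateGroupByMV(int length, int[][] groupKeysArray, double[][] valuesArray) {
    for (int i = 0; i < length; i++) {
      double[] values = valuesArray[i];
      for (int groupKey : groupKeysArray[i]) {
        for (double value : values) {
          if (value < _min[groupKey]) {
            _min[groupKey] = value;
          }
          if (value > _max[groupKey]) {
            _max[groupKey] = value;
          }
        }
      }
    }
  }

  double range(int groupKey) {
    return _max[groupKey] - _min[groupKey];
  }
}
```

For example, two rows `{1, 5}` (group 0) and `{3, 10}` (groups 0 and 1) give group 0 a range of 9 and group 1 a range of 7.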
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
index 580e7f3..d779d71 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
@@ -18,11 +18,8 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -34,29 +31,14 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
-public class PercentileAggregationFunction implements AggregationFunction<DoubleArrayList, Double> {
+public class PercentileAggregationFunction extends BaseSingleInputAggregationFunction<DoubleArrayList, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
protected final int _percentile;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileAggregationFunction(List<String> arguments) {
- int numArgs = arguments.size();
- Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
-
- _column = arguments.get(0);
- _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ public PercentileAggregationFunction(String column, int percentile) {
+ super(column);
+ _percentile = percentile;
}
@Override
@@ -66,17 +48,12 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
@Override
public String getColumnName() {
- return getType().getName() + _percentile + "_" + _column;
+ return AggregationFunctionType.PERCENTILE.getName() + _percentile + "_" + _column;
}
@Override
public String getResultColumnName() {
- return getType().getName().toLowerCase() + _percentile + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
+ return AggregationFunctionType.PERCENTILE.getName().toLowerCase() + _percentile + "(" + _column + ")";
}
@Override
@@ -95,9 +72,10 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
DoubleArrayList valueList = getValueList(aggregationResultHolder);
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
valueList.add(valueArray[i]);
}
@@ -105,8 +83,8 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
valueList.add(valueArray[i]);
@@ -115,8 +93,8 @@ public class PercentileAggregationFunction implements AggregationFunction<Double
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
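The percentile constructors change from `(List<String> arguments)`, which validated and parsed the arguments inline, to `(String column, int percentile)`. Argument resolution therefore has to happen before construction, at planning time. The commit message also mentions accepting the percentile as a second argument, e.g. `PERCENTILE(column, 99)`. The helper below is a hypothetical sketch of such planning-phase parsing, not Pinot's actual factory or `AggregationFunctionUtils` API. It accepts either a name suffix (`percentile95(col)`) or a second argument (`percentile(col, 99)`) and rejects values outside 0 to 100.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical helper, NOT Pinot's API: resolves (column, percentile) up front so that
// constructors can take (String column, int percentile) directly.
final class PercentileArgs {
  final String column;
  final int percentile;

  private PercentileArgs(String column, int percentile) {
    if (percentile < 0 || percentile > 100) {
      throw new IllegalArgumentException("Invalid percentile: " + percentile);
    }
    this.column = column;
    this.percentile = percentile;
  }

  /**
   * @param baseName     function family, e.g. "PERCENTILE" or "PERCENTILETDIGEST"
   * @param functionName name as written in the query, e.g. "percentile95"
   * @param arguments    arguments as strings, e.g. ["col"] or ["col", "99"]
   */
  static PercentileArgs parse(String baseName, String functionName, List<String> arguments) {
    String upperName = functionName.toUpperCase(Locale.ROOT);
    if (!upperName.startsWith(baseName)) {
      throw new IllegalArgumentException("Expected " + baseName + "*, got: " + functionName);
    }
    String suffix = upperName.substring(baseName.length());
    if (arguments.size() == 1 && !suffix.isEmpty()) {
      // Name-suffix form, e.g. PERCENTILE95(col).
      return new PercentileArgs(arguments.get(0), Integer.parseInt(suffix));
    }
    if (arguments.size() == 2 && suffix.isEmpty()) {
      // Two-argument form, e.g. PERCENTILE(col, 99).
      return new PercentileArgs(arguments.get(0), Integer.parseInt(arguments.get(1).trim()));
    }
    throw new IllegalArgumentException("Cannot resolve percentile for " + functionName + arguments);
  }
}
```

Parsing once at planning time matches the commit's stated goal of avoiding repeated expression compilation. The operator and executor then receive fully constructed functions.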
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
index 00cfd1d..5a89731 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
@@ -18,9 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import com.google.common.base.Preconditions;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -35,29 +32,14 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class PercentileEstAggregationFunction implements AggregationFunction<QuantileDigest, Long> {
+public class PercentileEstAggregationFunction extends BaseSingleInputAggregationFunction<QuantileDigest, Long> {
public static final double DEFAULT_MAX_ERROR = 0.05;
protected final int _percentile;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileEstAggregationFunction(List<String> arguments) {
- int numArgs = arguments.size();
- Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
-
- _column = arguments.get(0);
- _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ public PercentileEstAggregationFunction(String column, int percentile) {
+ super(column);
+ _percentile = percentile;
}
@Override
@@ -76,11 +58,6 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
}
@Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -97,8 +74,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
@Override
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
@@ -125,8 +102,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
@@ -150,8 +127,8 @@ public class PercentileEstAggregationFunction implements AggregationFunction<Qua
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
index 8809d42..bfc44b4 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
@@ -18,9 +18,9 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.function.customobject.QuantileDigest;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class PercentileEstMVAggregationFunction extends PercentileEstAggregationFunction {
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileEstMVAggregationFunction(List<String> arguments) {
- super(arguments);
+ public PercentileEstMVAggregationFunction(String column, int percentile) {
+ super(column, percentile);
}
@Override
@@ -63,8 +54,9 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
for (int i = 0; i < length; i++) {
for (long value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
for (int i = 0; i < length; i++) {
QuantileDigest quantileDigest = getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]);
for (long value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregation
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- long[][] valuesArray = blockValSetMap.get(_column).getLongValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ long[][] valuesArray = blockValSetMap.get(_expression).getLongValuesMV();
for (int i = 0; i < length; i++) {
long[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
index fba5f02..b02af5a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
@@ -19,9 +19,9 @@
package org.apache.pinot.core.query.aggregation.function;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class PercentileMVAggregationFunction extends PercentileAggregationFunction {
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileMVAggregationFunction(List<String> arguments) {
- super(arguments);
+ public PercentileMVAggregationFunction(String column, int percentile) {
+ super(column, percentile);
}
@Override
@@ -63,8 +54,9 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
DoubleArrayList valueList = getValueList(aggregationResultHolder);
for (int i = 0; i < length; i++) {
for (double value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
for (double value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFuncti
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
index faa9d71..bbe76dd 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java
@@ -18,10 +18,7 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import com.google.common.base.Preconditions;
import com.tdunning.math.stats.TDigest;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -38,28 +35,14 @@ import org.apache.pinot.spi.data.FieldSpec.DataType;
/**
* TDigest based Percentile aggregation function.
*/
-public class PercentileTDigestAggregationFunction implements AggregationFunction<TDigest, Double> {
+public class PercentileTDigestAggregationFunction extends BaseSingleInputAggregationFunction<TDigest, Double> {
public static final int DEFAULT_TDIGEST_COMPRESSION = 100;
protected final int _percentile;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileTDigestAggregationFunction(List<String> arguments) {
- int numArgs = arguments.size();
- Preconditions.checkArgument(numArgs == 2, getType() + " expects two argument, got: " + numArgs);
- _column = arguments.get(0);
- _percentile = AggregationFunctionUtils.parsePercentile(arguments.get(1));
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ public PercentileTDigestAggregationFunction(String column, int percentile) {
+ super(column);
+ _percentile = percentile;
}
@Override
@@ -78,11 +61,6 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
}
@Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -98,8 +76,9 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
TDigest tDigest = getDefaultTDigest(aggregationResultHolder);
@@ -126,8 +105,8 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
@@ -151,8 +130,8 @@ public class PercentileTDigestAggregationFunction implements AggregationFunction
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- BlockValSet blockValSet = blockValSetMap.get(_column);
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
double[] doubleValues = blockValSet.getDoubleValuesSV();
for (int i = 0; i < length; i++) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
index d71a7fe..01f4039 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java
@@ -19,9 +19,9 @@
package org.apache.pinot.core.query.aggregation.function;
import com.tdunning.math.stats.TDigest;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -29,17 +29,8 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAggregationFunction {
- /**
- * Constructor for the class.
- *
- * @param arguments List of arguments.
- * <ul>
- * <li> Arg 0: Column name to aggregate.</li>
- * <li> Arg 1: Percentile to compute. </li>
- * </ul>
- */
- public PercentileTDigestMVAggregationFunction(List<String> arguments) {
- super(arguments);
+ public PercentileTDigestMVAggregationFunction(String column, int percentile) {
+ super(column, percentile);
}
@Override
@@ -63,8 +54,9 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
TDigest tDigest = getDefaultTDigest(aggregationResultHolder);
for (int i = 0; i < length; i++) {
for (double value : valuesArray[i]) {
@@ -75,8 +67,8 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
TDigest tDigest = getDefaultTDigest(groupByResultHolder, groupKeyArray[i]);
for (double value : valuesArray[i]) {
@@ -87,8 +79,8 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAgg
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
index 2416919..7e71c21 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumAggregationFunction.java
@@ -18,8 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.function;
-import java.util.Collections;
-import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -31,18 +29,11 @@ import org.apache.pinot.core.query.aggregation.groupby.DoubleGroupByResultHolder
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-public class SumAggregationFunction implements AggregationFunction<Double, Double> {
+public class SumAggregationFunction extends BaseSingleInputAggregationFunction<Double, Double> {
private static final double DEFAULT_VALUE = 0.0;
- protected final String _column;
- private final List<TransformExpressionTree> _inputExpressions;
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public SumAggregationFunction(String column) {
- _column = column;
- _inputExpressions = Collections.singletonList(TransformExpressionTree.compileToExpressionTree(_column));
+ super(column);
}
@Override
@@ -51,21 +42,6 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public String getColumnName() {
- return getType().getName() + "_" + _column;
- }
-
- @Override
- public String getResultColumnName() {
- return getType().getName().toLowerCase() + "(" + _column + ")";
- }
-
- @Override
- public List<TransformExpressionTree> getInputExpressions() {
- return _inputExpressions;
- }
-
- @Override
public void accept(AggregationFunctionVisitorBase visitor) {
visitor.visit(this);
}
@@ -81,8 +57,9 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
double sum = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
sum += valueArray[i];
@@ -92,8 +69,8 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
groupByResultHolder.setValueForKey(groupKey, groupByResultHolder.getDoubleResult(groupKey) + valueArray[i]);
@@ -102,8 +79,8 @@ public class SumAggregationFunction implements AggregationFunction<Double, Doubl
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_column).getDoubleValuesSV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
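The SumAggregationFunction hunk above makes three changes. It deletes the `_column` field and the `getColumnName()`, `getResultColumnName()` and `getInputExpressions()` overrides. It then delegates to `super(column)` on `BaseSingleInputAggregationFunction`. Below is a minimal sketch of that pull-up refactoring, with several assumptions:

- It is not Pinot's base class. `SingleInputFunctionSketch` and `SumSketch` are hypothetical names.
- The expression is modeled as a plain `String`. The diff shows that the real `_expression` is used as a `TransformExpressionTree` map key.
- The two name formats are copied from the methods this hunk removes. The base class in the commit may build them differently.

```java
import java.util.Collections;
import java.util.List;

// Entry point: the concrete function inherits its naming logic from the base class.
class BaseClassSketchDemo {
  public static void main(String[] args) {
    SumSketch sum = new SumSketch("price");
    System.out.println(sum.getColumnName());        // SUM_price
    System.out.println(sum.getResultColumnName());  // sum(price)
    System.out.println(sum.aggregate(new double[]{1.0, 2.5, 3.5}, 3, 0.0)); // 7.0
  }
}

// Hypothetical single-input base class: holds the one input expression and derives names from it.
abstract class SingleInputFunctionSketch {
  protected final String _expression;
  private final List<String> _inputExpressions;

  protected SingleInputFunctionSketch(String expression) {
    _expression = expression;
    _inputExpressions = Collections.singletonList(expression);
  }

  // Each concrete function only supplies its type name.
  protected abstract String getTypeName();

  // Formats copied from the getColumnName()/getResultColumnName() bodies removed above.
  public String getColumnName() {
    return getTypeName() + "_" + _expression;
  }

  public String getResultColumnName() {
    return getTypeName().toLowerCase() + "(" + _expression + ")";
  }

  public List<String> getInputExpressions() {
    return _inputExpressions;
  }
}

// The concrete class shrinks to a constructor plus its type-specific logic.
class SumSketch extends SingleInputFunctionSketch {
  SumSketch(String column) {
    super(column);
  }

  @Override
  protected String getTypeName() {
    return "SUM";
  }

  // Adds a block of single-value doubles to the running sum, like aggregate() in the hunk.
  double aggregate(double[] values, int length, double current) {
    double sum = current;
    for (int i = 0; i < length; i++) {
      sum += values[i];
    }
    return sum;
  }
}
```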
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
index 8c468a7..7d7f9a0 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/SumMVAggregationFunction.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Map;
import org.apache.pinot.common.function.AggregationFunctionType;
+import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
@@ -27,10 +28,6 @@ import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
public class SumMVAggregationFunction extends SumAggregationFunction {
- /**
- * Constructor for the class.
- * @param column Column name to aggregate on.
- */
public SumMVAggregationFunction(String column) {
super(column);
}
@@ -46,8 +43,9 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
}
@Override
- public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
double sum = aggregationResultHolder.getDoubleResult();
for (int i = 0; i < length; i++) {
for (double value : valuesArray[i]) {
@@ -59,8 +57,8 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
int groupKey = groupKeyArray[i];
double sum = groupByResultHolder.getDoubleResult(groupKey);
@@ -73,8 +71,8 @@ public class SumMVAggregationFunction extends SumAggregationFunction {
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
- Map<String, BlockValSet> blockValSetMap) {
- double[][] valuesArray = blockValSetMap.get(_column).getDoubleValuesMV();
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
+ double[][] valuesArray = blockValSetMap.get(_expression).getDoubleValuesMV();
for (int i = 0; i < length; i++) {
double[] values = valuesArray[i];
for (int groupKey : groupKeysArray[i]) {
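The Sum and SumMV hunks above share one loop shape. They differ only in two ways: whether a document carries one value or an array of values, and whether it maps to one group key or an array of group keys. The sketch below restates three of those loops over plain arrays. It makes two assumptions:

- A `double[]` indexed by group key stands in for `GroupByResultHolder`.
- In the MV-column loop, all of a document's values are added to its group. The hunk is cut off before that loop body.

```java
import java.util.Arrays;

// Hypothetical sketch: results[groupKey] stands in for GroupByResultHolder.
class GroupBySumSketch {
  // SV column, SV group-by: each doc adds its value to exactly one group.
  static void sumGroupBySV(int length, int[] groupKeyArray, double[] values, double[] results) {
    for (int i = 0; i < length; i++) {
      results[groupKeyArray[i]] += values[i];
    }
  }

  // SV column, MV group-by: each doc adds its value to every group it belongs to.
  static void sumGroupByMV(int length, int[][] groupKeysArray, double[] values, double[] results) {
    for (int i = 0; i < length; i++) {
      double value = values[i];
      for (int groupKey : groupKeysArray[i]) {
        results[groupKey] += value;
      }
    }
  }

  // MV column, SV group-by: all values of a doc are added to that doc's single group.
  static void sumMVGroupBySV(int length, int[] groupKeyArray, double[][] valuesArray, double[] results) {
    for (int i = 0; i < length; i++) {
      int groupKey = groupKeyArray[i];
      double sum = results[groupKey];
      for (double value : valuesArray[i]) {
        sum += value;
      }
      results[groupKey] = sum;
    }
  }

  public static void main(String[] args) {
    double[] results = new double[2];
    // Doc 0 -> group 0, doc 1 -> group 1, doc 2 -> group 0.
    sumGroupBySV(3, new int[]{0, 1, 0}, new double[]{1.0, 2.0, 3.0}, results);
    System.out.println(Arrays.toString(results)); // [4.0, 2.0]
  }
}
```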
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
index dea2bd8..a2d561a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/DefaultGroupByExecutor.java
@@ -18,21 +18,15 @@
*/
package org.apache.pinot.core.query.aggregation.groupby;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import javax.annotation.Nonnull;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.GroupBy;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.plan.DocIdSetPlanNode;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
/**
@@ -51,62 +45,39 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
private static final ThreadLocal<int[][]> THREAD_LOCAL_MV_GROUP_KEYS =
ThreadLocal.withInitial(() -> new int[DocIdSetPlanNode.MAX_DOC_PER_CALL][]);
- protected final int _numFunctions;
- protected final AggregationFunction[] _functions;
- protected final TransformExpressionTree[][] _aggregationExpressions;
+ protected final AggregationFunction[] _aggregationFunctions;
protected final GroupKeyGenerator _groupKeyGenerator;
- protected final GroupByResultHolder[] _resultHolders;
+ protected final GroupByResultHolder[] _groupByResultHolders;
protected final boolean _hasMVGroupByExpression;
- protected final boolean _hasNoDictionaryGroupByExpression;
protected final int[] _svGroupKeys;
protected final int[][] _mvGroupKeys;
/**
* Constructor for the class.
*
- * @param functionContexts Array of aggregation functions
- * @param groupBy Group by from broker request
+ * @param aggregationFunctions Array of aggregation functions
+ * @param groupByExpressions Array of group-by expressions
* @param maxInitialResultHolderCapacity Maximum initial capacity for the result holder
* @param numGroupsLimit Limit on number of aggregation groups returned in the result
* @param transformOperator Transform operator
*/
- public DefaultGroupByExecutor(@Nonnull AggregationFunctionContext[] functionContexts, @Nonnull GroupBy groupBy,
- int maxInitialResultHolderCapacity, int numGroupsLimit, @Nonnull TransformOperator transformOperator) {
- // Initialize aggregation functions and expressions
- _numFunctions = functionContexts.length;
- _functions = new AggregationFunction[_numFunctions];
- _aggregationExpressions = new TransformExpressionTree[_numFunctions][];
+ public DefaultGroupByExecutor(AggregationFunction[] aggregationFunctions,
+ TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+ TransformOperator transformOperator) {
+ _aggregationFunctions = aggregationFunctions;
- for (int i = 0; i < _numFunctions; i++) {
- AggregationFunction function = functionContexts[i].getAggregationFunction();
- _functions[i] = function;
-
- if (function.getType() != AggregationFunctionType.COUNT) {
- List<String> expressions = functionContexts[i].getExpressions();
-
- List<TransformExpressionTree> inputExpressions = function.getInputExpressions();
- _aggregationExpressions[i] = inputExpressions.toArray(new TransformExpressionTree[0]);
- }
- }
-
- // Initialize group-by expressions
- List<String> groupByExpressionStrings = groupBy.getExpressions();
- int numGroupByExpressions = groupByExpressionStrings.size();
boolean hasMVGroupByExpression = false;
boolean hasNoDictionaryGroupByExpression = false;
- TransformExpressionTree[] groupByExpressions = new TransformExpressionTree[numGroupByExpressions];
- for (int i = 0; i < numGroupByExpressions; i++) {
- groupByExpressions[i] = TransformExpressionTree.compileToExpressionTree(groupByExpressionStrings.get(i));
- TransformResultMetadata transformResultMetadata = transformOperator.getResultMetadata(groupByExpressions[i]);
+ for (TransformExpressionTree groupByExpression : groupByExpressions) {
+ TransformResultMetadata transformResultMetadata = transformOperator.getResultMetadata(groupByExpression);
hasMVGroupByExpression |= !transformResultMetadata.isSingleValue();
hasNoDictionaryGroupByExpression |= !transformResultMetadata.hasDictionary();
}
_hasMVGroupByExpression = hasMVGroupByExpression;
- _hasNoDictionaryGroupByExpression = hasNoDictionaryGroupByExpression;
// Initialize group key generator
- if (_hasNoDictionaryGroupByExpression) {
- if (numGroupByExpressions == 1) {
+ if (hasNoDictionaryGroupByExpression) {
+ if (groupByExpressions.length == 1) {
_groupKeyGenerator =
new NoDictionarySingleColumnGroupKeyGenerator(transformOperator, groupByExpressions[0], numGroupsLimit);
} else {
@@ -121,9 +92,10 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
// Initialize result holders
int maxNumResults = _groupKeyGenerator.getGlobalGroupKeyUpperBound();
int initialCapacity = Math.min(maxNumResults, maxInitialResultHolderCapacity);
- _resultHolders = new GroupByResultHolder[_numFunctions];
- for (int i = 0; i < _numFunctions; i++) {
- _resultHolders[i] = _functions[i].createGroupByResultHolder(initialCapacity, maxNumResults);
+ int numAggregationFunctions = aggregationFunctions.length;
+ _groupByResultHolders = new GroupByResultHolder[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ _groupByResultHolders[i] = _aggregationFunctions[i].createGroupByResultHolder(initialCapacity, maxNumResults);
}
// Initialize map from document Id to group key
@@ -137,7 +109,7 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
}
@Override
- public void process(@Nonnull TransformBlock transformBlock) {
+ public void process(TransformBlock transformBlock) {
// Generate group keys
// NOTE: groupKeyGenerator will limit the number of groups. Once reaching limit, no new group will be generated
if (_hasMVGroupByExpression) {
@@ -146,41 +118,31 @@ public class DefaultGroupByExecutor implements GroupByExecutor {
_groupKeyGenerator.generateKeysForBlock(transformBlock, _svGroupKeys);
}
- int length = transformBlock.getNumDocs();
int capacityNeeded = _groupKeyGenerator.getCurrentGroupKeyUpperBound();
- for (int i = 0; i < _numFunctions; i++) {
- GroupByResultHolder resultHolder = _resultHolders[i];
- resultHolder.ensureCapacity(capacityNeeded);
+ int length = transformBlock.getNumDocs();
+ int numAggregationFunctions = _aggregationFunctions.length;
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ GroupByResultHolder groupByResultHolder = _groupByResultHolders[i];
+ groupByResultHolder.ensureCapacity(capacityNeeded);
aggregate(transformBlock, length, i);
}
}
- protected void aggregate(@Nonnull TransformBlock transformBlock, int length, int functionIndex) {
- AggregationFunction function = _functions[functionIndex];
- GroupByResultHolder resultHolder = _resultHolders[functionIndex];
+ protected void aggregate(TransformBlock transformBlock, int length, int functionIndex) {
+ AggregationFunction aggregationFunction = _aggregationFunctions[functionIndex];
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap =
+ AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock);
- if (function.getType() == AggregationFunctionType.COUNT) {
- if (_hasMVGroupByExpression) {
- function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, Collections.emptyMap());
- } else {
- function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, Collections.emptyMap());
- }
+ GroupByResultHolder groupByResultHolder = _groupByResultHolders[functionIndex];
+ if (_hasMVGroupByExpression) {
+ aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, groupByResultHolder, blockValSetMap);
} else {
- Map<String, BlockValSet> blockValSetMap = new HashMap<>();
- for (int i = 0; i < _aggregationExpressions[functionIndex].length; i++) {
- TransformExpressionTree aggregationExpression = _aggregationExpressions[functionIndex][i];
- blockValSetMap.put(aggregationExpression.toString(), transformBlock.getBlockValueSet(aggregationExpression));
- }
- if (_hasMVGroupByExpression) {
- function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, blockValSetMap);
- } else {
- function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, blockValSetMap);
- }
+ aggregationFunction.aggregateGroupBySV(length, _svGroupKeys, groupByResultHolder, blockValSetMap);
}
}
@Override
public AggregationGroupByResult getResult() {
- return new AggregationGroupByResult(_groupKeyGenerator, _functions, _resultHolders);
+ return new AggregationGroupByResult(_groupKeyGenerator, _aggregationFunctions, _groupByResultHolders);
}
}
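In DefaultGroupByExecutor, the old code built a `HashMap<String, BlockValSet>` on every call, keyed by `aggregationExpression.toString()`. The new code calls `AggregationFunctionUtils.getBlockValSetMap(aggregationFunction, transformBlock)`, which is keyed by the expression object itself. The following sketch illustrates that keying scheme under these assumptions:

- `ExprKey` is a hypothetical immutable stand-in for `TransformExpressionTree` that caches its hash code.
- A `Map<ExprKey, double[]>` stands in for the transform block.
- The `getBlockValSetMap` shown here is an illustration only. The real utility's implementation does not appear in this diff.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

class BlockValSetMapSketch {
  // One entry per input expression, keyed by the expression object rather than its string form.
  static Map<ExprKey, double[]> getBlockValSetMap(List<ExprKey> inputExpressions,
      Map<ExprKey, double[]> block) {
    Map<ExprKey, double[]> blockValSetMap = new HashMap<>();
    for (ExprKey expression : inputExpressions) {
      blockValSetMap.put(expression, block.get(expression));
    }
    return blockValSetMap;
  }

  public static void main(String[] args) {
    ExprKey doubled = new ExprKey("times", new ExprKey("price"), new ExprKey("2"));
    Map<ExprKey, double[]> block = new HashMap<>();
    block.put(new ExprKey("price"), new double[]{1.0, 2.0});
    block.put(doubled, new double[]{2.0, 4.0});

    Map<ExprKey, double[]> blockValSetMap = getBlockValSetMap(Arrays.asList(doubled), block);
    // A separately built but equal key finds the same entry, with no toString() round trip.
    ExprKey lookup = new ExprKey("times", new ExprKey("price"), new ExprKey("2"));
    System.out.println(blockValSetMap.get(lookup)[1]); // 4.0
  }
}

// Hypothetical immutable expression key: a name plus child expressions, hash computed once.
final class ExprKey {
  private final String _name;
  private final List<ExprKey> _children;
  private final int _hashCode;

  ExprKey(String name, ExprKey... children) {
    _name = name;
    _children = Arrays.asList(children);
    _hashCode = Objects.hash(_name, _children);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ExprKey)) {
      return false;
    }
    ExprKey other = (ExprKey) o;
    // Compare the cached hash first as a cheap filter before the structural comparison.
    return _hashCode == other._hashCode && _name.equals(other._name) && _children.equals(other._children);
  }

  @Override
  public int hashCode() {
    return _hashCode;
  }
}
```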
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
index 75ef1a1..c1266c5 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/groupby/GroupByExecutor.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.query.aggregation.groupby;
-import javax.annotation.Nonnull;
import org.apache.pinot.core.operator.blocks.TransformBlock;
@@ -32,7 +31,7 @@ public interface GroupByExecutor {
*
* @param transformBlock Transform block
*/
- void process(@Nonnull TransformBlock transformBlock);
+ void process(TransformBlock transformBlock);
/**
* Returns the result of group-by aggregation.
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
index 3bde93a..825ebf2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java
@@ -20,19 +20,16 @@ package org.apache.pinot.core.query.reduce;
import java.io.Serializable;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.response.broker.AggregationResult;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.common.response.broker.ResultTable;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataTable;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.transport.ServerRoutingInstance;
@@ -44,16 +41,12 @@ import org.apache.pinot.core.util.QueryOptions;
*/
public class AggregationDataTableReducer implements DataTableReducer {
private final AggregationFunction[] _aggregationFunctions;
- private final List<AggregationInfo> _aggregationInfos;
- private final int _numAggregationFunctions;
private final boolean _preserveType;
private final boolean _responseFormatSql;
AggregationDataTableReducer(BrokerRequest brokerRequest, AggregationFunction[] aggregationFunctions,
QueryOptions queryOptions) {
_aggregationFunctions = aggregationFunctions;
- _aggregationInfos = brokerRequest.getAggregationsInfo();
- _numAggregationFunctions = aggregationFunctions.length;
_preserveType = queryOptions.isPreserveType();
_responseFormatSql = queryOptions.isResponseFormatSQL();
}
@@ -67,7 +60,6 @@ public class AggregationDataTableReducer implements DataTableReducer {
public void reduceAndSetResults(String tableName, DataSchema dataSchema,
Map<ServerRoutingInstance, DataTable> dataTableMap, BrokerResponseNative brokerResponseNative,
BrokerMetrics brokerMetrics) {
-
if (dataTableMap.isEmpty()) {
if (_responseFormatSql) {
DataSchema finalDataSchema = getResultTableDataSchema();
@@ -76,14 +68,11 @@ public class AggregationDataTableReducer implements DataTableReducer {
return;
}
- assert dataSchema != null;
-
- Collection<DataTable> dataTables = dataTableMap.values();
-
- // Merge results from all data tables.
- Object[] intermediateResults = new Object[_numAggregationFunctions];
- for (DataTable dataTable : dataTables) {
- for (int i = 0; i < _numAggregationFunctions; i++) {
+ // Merge results from all data tables
+ int numAggregationFunctions = _aggregationFunctions.length;
+ Object[] intermediateResults = new Object[numAggregationFunctions];
+ for (DataTable dataTable : dataTableMap.values()) {
+ for (int i = 0; i < numAggregationFunctions; i++) {
Object intermediateResultToMerge;
DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i);
switch (columnDataType) {
@@ -120,8 +109,9 @@ public class AggregationDataTableReducer implements DataTableReducer {
*/
private ResultTable reduceToResultTable(Object[] intermediateResults) {
List<Object[]> rows = new ArrayList<>(1);
- Object[] row = new Object[_numAggregationFunctions];
- for (int i = 0; i < _numAggregationFunctions; i++) {
+ int numAggregationFunctions = _aggregationFunctions.length;
+ Object[] row = new Object[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
row[i] = _aggregationFunctions[i].extractFinalResult(intermediateResults[i]);
}
rows.add(row);
@@ -134,9 +124,10 @@ public class AggregationDataTableReducer implements DataTableReducer {
* Sets aggregation results into AggregationResults
*/
private List<AggregationResult> reduceToAggregationResult(Object[] intermediateResults, DataSchema dataSchema) {
- // Extract final results and set them into the broker response.
- List<AggregationResult> reducedAggregationResults = new ArrayList<>(_numAggregationFunctions);
- for (int i = 0; i < _numAggregationFunctions; i++) {
+ // Extract final results and set them into the broker response
+ int numAggregationFunctions = _aggregationFunctions.length;
+ List<AggregationResult> reducedAggregationResults = new ArrayList<>(numAggregationFunctions);
+ for (int i = 0; i < numAggregationFunctions; i++) {
Serializable resultValue = AggregationFunctionUtils
.getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i]));
@@ -153,13 +144,13 @@ public class AggregationDataTableReducer implements DataTableReducer {
* Constructs the data schema for the final results table
*/
private DataSchema getResultTableDataSchema() {
- String[] finalColumnNames = new String[_numAggregationFunctions];
- DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[_numAggregationFunctions];
- for (int i = 0; i < _numAggregationFunctions; i++) {
- AggregationFunctionContext aggregationFunctionContext =
- AggregationFunctionUtils.getAggregationFunctionContext(_aggregationInfos.get(i));
- finalColumnNames[i] = aggregationFunctionContext.getResultColumnName();
- finalColumnDataTypes[i] = _aggregationFunctions[i].getFinalResultColumnType();
+ int numAggregationFunctions = _aggregationFunctions.length;
+ String[] finalColumnNames = new String[numAggregationFunctions];
+ DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ AggregationFunction aggregationFunction = _aggregationFunctions[i];
+ finalColumnNames[i] = aggregationFunction.getResultColumnName();
+ finalColumnDataTypes[i] = aggregationFunction.getFinalResultColumnType();
}
return new DataSchema(finalColumnNames, finalColumnDataTypes);
}
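AggregationDataTableReducer follows a two-step pattern. First it merges each function's intermediate results across all server data tables. Then it calls `extractFinalResult` once per function. Below is a generic sketch of that pattern, with these assumptions:

- `AggFn` is a hypothetical interface with only `merge` and `extractFinalResult`.
- The average function is a hypothetical example whose intermediate result is a `{sum, count}` pair. Pinot's own AVG class is not part of this diff.
- Seeding the accumulator with the first result is assumed, because the hunk omits the merge body.

```java
import java.util.Arrays;
import java.util.List;

class ReduceSketch {
  // Merge all per-server intermediate results, then extract the final result once.
  static <I, F> F reduce(AggFn<I, F> fn, List<I> perServerResults) {
    I merged = null;
    for (I result : perServerResults) {
      // The first result seeds the accumulator; later ones are merged into it.
      merged = (merged == null) ? result : fn.merge(merged, result);
    }
    return fn.extractFinalResult(merged);
  }

  public static void main(String[] args) {
    List<double[]> perServer = Arrays.asList(new double[]{10.0, 2.0}, new double[]{20.0, 3.0});
    System.out.println(reduce(new AvgFn(), perServer)); // 6.0
  }
}

// Hypothetical, simplified stand-in for the merge/extractFinalResult pair of an aggregation function.
interface AggFn<I, F> {
  I merge(I a, I b);

  F extractFinalResult(I intermediate);
}

// Hypothetical average: intermediate is {sum, count}, final result is sum / count.
class AvgFn implements AggFn<double[], Double> {
  @Override
  public double[] merge(double[] a, double[] b) {
    return new double[]{a[0] + b[0], a[1] + b[1]};
  }

  @Override
  public Double extractFinalResult(double[] intermediate) {
    return intermediate[0] / intermediate[1];
  }
}
```

Keeping `{sum, count}` as the intermediate value, rather than a per-server average, is what makes the merge exact. Averages of averages would weight servers incorrectly.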
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
index 79682e9..0e265e8 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/CombineService.java
@@ -27,7 +27,7 @@ import org.apache.pinot.common.request.Selection;
import org.apache.pinot.common.response.ProcessingException;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -68,19 +68,19 @@ public class CombineService {
return;
}
- AggregationFunctionContext[] mergedAggregationFunctionContexts = mergedBlock.getAggregationFunctionContexts();
- if (mergedAggregationFunctionContexts == null) {
+ AggregationFunction[] mergedAggregationFunctions = mergedBlock.getAggregationFunctions();
+ if (mergedAggregationFunctions == null) {
// No data in merged block.
- mergedBlock.setAggregationFunctionContexts(blockToMerge.getAggregationFunctionContexts());
+ mergedBlock.setAggregationFunctions(blockToMerge.getAggregationFunctions());
mergedBlock.setAggregationResults(aggregationResultToMerge);
- }
-
- // Merge two block.
- List<Object> mergedAggregationResult = mergedBlock.getAggregationResult();
- int numAggregationFunctions = mergedAggregationFunctionContexts.length;
- for (int i = 0; i < numAggregationFunctions; i++) {
- mergedAggregationResult.set(i, mergedAggregationFunctionContexts[i].getAggregationFunction()
- .merge(mergedAggregationResult.get(i), aggregationResultToMerge.get(i)));
+ } else {
+ // Merge two blocks.
+ List<Object> mergedAggregationResult = mergedBlock.getAggregationResult();
+ int numAggregationFunctions = mergedAggregationFunctions.length;
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ mergedAggregationResult.set(i,
+ mergedAggregationFunctions[i].merge(mergedAggregationResult.get(i), aggregationResultToMerge.get(i)));
+ }
}
} else {
// Combine aggregation group-by result, which should not come into CombineService.
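The CombineService hunk adds an `else` branch. In the old code, when the merged block was empty, it adopted `blockToMerge`'s contexts and results. It then fell through and read `.length` from the still-null local `mergedAggregationFunctionContexts`. Below is a minimal sketch of the corrected control flow. `combine` is a hypothetical method over plain result lists, and the per-function `BinaryOperator`s stand in for `AggregationFunction#merge`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.BinaryOperator;

class CombineSketch {
  // mergers.get(i) plays the role of AggregationFunction#merge for the i-th function (hypothetical).
  static List<Object> combine(List<Object> merged, List<Object> toMerge, List<BinaryOperator<Object>> mergers) {
    if (merged == null) {
      // Nothing merged yet: adopt the incoming results and stop here.
      return new ArrayList<>(toMerge);
    } else {
      // Both sides have results: merge them position by position.
      for (int i = 0; i < mergers.size(); i++) {
        merged.set(i, mergers.get(i).apply(merged.get(i), toMerge.get(i)));
      }
      return merged;
    }
  }

  public static void main(String[] args) {
    BinaryOperator<Object> sum = (a, b) -> (Double) a + (Double) b;
    List<BinaryOperator<Object>> mergers = Collections.singletonList(sum);
    List<Object> merged = combine(null, Arrays.<Object>asList(1.0), mergers);
    merged = combine(merged, Arrays.<Object>asList(2.5), mergers);
    System.out.println(merged); // [3.5]
  }
}
```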
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
index ec96f0e..9841c10 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ComparisonFunction.java
@@ -19,7 +19,7 @@
package org.apache.pinot.core.query.reduce;
import org.apache.pinot.common.request.AggregationInfo;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
//This class will be inherited by different classes that compare (e.g., for equality) the input value by the base value
@@ -27,8 +27,7 @@ public abstract class ComparisonFunction {
private final String _functionExpression;
protected ComparisonFunction(AggregationInfo aggregationInfo) {
- _functionExpression =
- AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo).getAggregationColumnName();
+ _functionExpression = AggregationFunctionFactory.getAggregationFunction(aggregationInfo, null).getColumnName();
}
public abstract boolean isComparisonValid(String aggResult);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
index de5db1c..9c5f222 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/DistinctDataTableReducer.java
@@ -76,8 +76,7 @@ public class DistinctDataTableReducer implements DataTableReducer {
DataSchema finalDataSchema = getEmptyResultTableDataSchema();
brokerResponseNative.setResultTable(new ResultTable(finalDataSchema, Collections.emptyList()));
} else {
- brokerResponseNative
- .setSelectionResults(new SelectionResults(getDistinctColumns(), Collections.emptyList()));
+ brokerResponseNative.setSelectionResults(new SelectionResults(getDistinctColumns(), Collections.emptyList()));
}
return;
}
@@ -163,13 +162,15 @@ public class DistinctDataTableReducer implements DataTableReducer {
}
private List<String> getDistinctColumns() {
- return AggregationFunctionUtils.getAggregationExpressions(_brokerRequest.getAggregationsInfo().get(0));
+ return AggregationFunctionUtils.getArguments(_brokerRequest.getAggregationsInfo().get(0));
}
private DataSchema getEmptyResultTableDataSchema() {
- String[] columns = getDistinctColumns().toArray(new String[0]);
- DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[columns.length];
+ List<String> distinctColumns = getDistinctColumns();
+ int numColumns = distinctColumns.size();
+ String[] columnNames = distinctColumns.toArray(new String[numColumns]);
+ DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[numColumns];
Arrays.fill(columnDataTypes, DataSchema.ColumnDataType.STRING);
- return new DataSchema(columns, columnDataTypes);
+ return new DataSchema(columnNames, columnDataTypes);
}
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
index 97c2019..98e6a01 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java
@@ -31,7 +31,6 @@ import java.util.function.BiFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.metrics.BrokerMeter;
import org.apache.pinot.common.metrics.BrokerMetrics;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.GroupBy;
@@ -48,7 +47,6 @@ import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
import org.apache.pinot.core.data.table.IndexedTable;
import org.apache.pinot.core.data.table.Record;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
@@ -63,8 +61,6 @@ import org.apache.pinot.core.util.QueryOptions;
public class GroupByDataTableReducer implements DataTableReducer {
private final BrokerRequest _brokerRequest;
private final AggregationFunction[] _aggregationFunctions;
- private final List<AggregationInfo> _aggregationInfos;
- private final AggregationFunctionContext[] _aggregationFunctionContexts;
private final List<SelectionSort> _orderBy;
private final GroupBy _groupBy;
private final int _numAggregationFunctions;
@@ -80,8 +76,6 @@ public class GroupByDataTableReducer implements DataTableReducer {
QueryOptions queryOptions) {
_brokerRequest = brokerRequest;
_aggregationFunctions = aggregationFunctions;
- _aggregationInfos = brokerRequest.getAggregationsInfo();
- _aggregationFunctionContexts = AggregationFunctionUtils.getAggregationFunctionContexts(_brokerRequest);
_numAggregationFunctions = aggregationFunctions.length;
_groupBy = brokerRequest.getGroupBy();
_numGroupBy = _groupBy.getExpressionsSize();
@@ -154,7 +148,7 @@ public class GroupByDataTableReducer implements DataTableReducer {
// This is the primary PQL compliant group by
boolean[] aggregationFunctionSelectStatus =
- AggregationFunctionUtils.getAggregationFunctionsSelectStatus(_aggregationInfos);
+ AggregationFunctionUtils.getAggregationFunctionsSelectStatus(_brokerRequest.getAggregationsInfo());
setGroupByHavingResults(brokerResponseNative, aggregationFunctionSelectStatus, dataTables,
_brokerRequest.getHavingFilterQuery(), _brokerRequest.getHavingFilterSubQueryMap());
@@ -303,10 +297,9 @@ public class GroupByDataTableReducer implements DataTableReducer {
}
private IndexedTable getIndexedTable(DataSchema dataSchema, Collection<DataTable> dataTables) {
-
int indexedTableCapacity = GroupByUtils.getTableCapacity(_groupBy, _orderBy);
IndexedTable indexedTable =
- new ConcurrentIndexedTable(dataSchema, _aggregationInfos, _orderBy, indexedTableCapacity);
+ new ConcurrentIndexedTable(dataSchema, _aggregationFunctions, _orderBy, indexedTableCapacity);
for (DataTable dataTable : dataTables) {
BiFunction[] functions = new BiFunction[_numColumns];
@@ -552,8 +545,9 @@ public class GroupByDataTableReducer implements DataTableReducer {
if (aggregationFunctionsSelectStatus[i]) {
finalColumnNames[count] = columnNames[i];
finalOutResultMaps[count] = finalResultMaps[i];
- finalResultTableAggNames[count] = _aggregationFunctionContexts[i].getResultColumnName();
- finalAggregationFunctions[count] = _aggregationFunctions[i];
+ AggregationFunction aggregationFunction = _aggregationFunctions[i];
+ finalResultTableAggNames[count] = aggregationFunction.getResultColumnName();
+ finalAggregationFunctions[count] = aggregationFunction;
count++;
}
}
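In GroupByDataTableReducer the cached AggregationFunctionContext array is gone. Each result column name is now read directly from the AggregationFunction, and only the functions whose select status is true are kept. Below is a minimal sketch of that filtering step. It assumes a hypothetical one-method ResultNamed interface in place of Pinot's AggregationFunction, and the column names are illustrative rather than Pinot's exact naming format.

```java
import java.util.ArrayList;
import java.util.List;

final class SelectedAggregations {
  // Returns the result column names of the selected aggregation functions, in their original order.
  // Mirrors the count-based compaction loop in the reducer, but collects into a list instead of arrays.
  static List<String> selectedResultColumnNames(ResultNamed[] functions, boolean[] selectStatus) {
    if (functions.length != selectStatus.length) {
      throw new IllegalArgumentException("functions and selectStatus must have the same length");
    }
    List<String> names = new ArrayList<>();
    for (int i = 0; i < functions.length; i++) {
      if (selectStatus[i]) {
        names.add(functions[i].getResultColumnName());
      }
    }
    return names;
  }

  public static void main(String[] args) {
    ResultNamed[] functions = {() -> "sum_m1", () -> "max_m2", () -> "avg_m3"};
    // Prints [sum_m1, avg_m3]
    System.out.println(selectedResultColumnNames(functions, new boolean[]{true, false, true}));
  }
}

// Hypothetical stand-in for Pinot's AggregationFunction; only the result column name is modeled.
interface ResultNamed {
  String getResultColumnName();
}
```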
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java b/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
index 3865f2e..bf2fa07 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/request/ServerQueryRequest.java
@@ -95,7 +95,7 @@ public class ServerQueryRequest {
_aggregationExpressions = new HashSet<>();
for (AggregationInfo aggregationInfo : aggregationsInfo) {
if (!aggregationInfo.getAggregationType().equalsIgnoreCase(AggregationFunctionType.COUNT.getName())) {
- for (String expressions : AggregationFunctionUtils.getAggregationExpressions(aggregationInfo)) {
+ for (String expressions : AggregationFunctionUtils.getArguments(aggregationInfo)) {
_aggregationExpressions.add(TransformExpressionTree.compileToExpressionTree(expressions));
}
}
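In ServerQueryRequest the helper is renamed from getAggregationExpressions to getArguments. The surrounding loop is unchanged: every argument of every non-COUNT aggregation is compiled into a TransformExpressionTree and added to a set. Below is a minimal sketch of that collection step. It assumes a hypothetical Agg class in place of the Thrift AggregationInfo and keeps the arguments as plain strings rather than compiling them.

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

final class AggregationExpressionCollector {
  // Hypothetical, simplified stand-in for AggregationInfo: a function type plus its arguments.
  static final class Agg {
    final String type;
    final List<String> arguments;

    Agg(String type, String... arguments) {
      this.type = type;
      this.arguments = Arrays.asList(arguments);
    }
  }

  // Collects the distinct arguments of all non-COUNT aggregations, in first-seen order.
  // COUNT is skipped with a case-insensitive comparison, as in the original loop.
  static Set<String> collect(List<Agg> aggs) {
    Set<String> expressions = new LinkedHashSet<>();
    for (Agg agg : aggs) {
      if (!agg.type.equalsIgnoreCase("count")) {
        expressions.addAll(agg.arguments);
      }
    }
    return expressions;
  }

  public static void main(String[] args) {
    List<Agg> aggs = Arrays.asList(new Agg("COUNT", "*"), new Agg("sum", "m1"), new Agg("max", "m1"));
    System.out.println(collect(aggs)); // prints [m1]
  }
}
```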
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
index f8e9bd1..3810826 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/StarTreeUtils.java
@@ -18,7 +18,6 @@
*/
package org.apache.pinot.core.startree;
-import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
@@ -27,8 +26,6 @@ import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.FilterOperator;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.common.utils.request.FilterQueryTree;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
import org.apache.pinot.core.startree.v2.StarTreeV2Metadata;
@@ -57,8 +54,8 @@ public class StarTreeUtils {
* </ul>
*/
public static boolean isFitForStarTree(StarTreeV2Metadata starTreeV2Metadata,
- Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs,
- @Nullable Set<TransformExpressionTree> groupByExpressions, @Nullable FilterQueryTree rootFilterNode) {
+ AggregationFunctionColumnPair[] aggregationFunctionColumnPairs,
+ @Nullable TransformExpressionTree[] groupByExpressions, @Nullable FilterQueryTree rootFilterNode) {
// Check aggregations
for (AggregationFunctionColumnPair aggregationFunctionColumnPair : aggregationFunctionColumnPairs) {
if (!starTreeV2Metadata.containsFunctionColumnPair(aggregationFunctionColumnPair)) {
@@ -102,29 +99,4 @@ public class StarTreeUtils {
String column = filterNode.getColumn();
return starTreeDimensions.contains(column);
}
-
- /**
- * Creates a {@link AggregationFunctionContext} from the given context but replace the column with the function-column
- * pair.
- */
- public static AggregationFunctionContext createStarTreeFunctionContext(AggregationFunctionContext functionContext) {
- AggregationFunction function = functionContext.getAggregationFunction();
- AggregationFunctionColumnPair functionColumnPair =
- new AggregationFunctionColumnPair(function.getType(), functionContext.getColumnName());
- return new AggregationFunctionContext(function, Arrays.asList(functionColumnPair.toColumnName()));
- }
-
- /**
- * Creates an array of {@link AggregationFunctionContext}s from the given contexts but replace the column with the
- * function-column pair.
- */
- public static AggregationFunctionContext[] createStarTreeFunctionContexts(
- AggregationFunctionContext[] functionContexts) {
- int numContexts = functionContexts.length;
- AggregationFunctionContext[] starTreeFunctionContexts = new AggregationFunctionContext[numContexts];
- for (int i = 0; i < numContexts; i++) {
- starTreeFunctionContexts[i] = createStarTreeFunctionContext(functionContexts[i]);
- }
- return starTreeFunctionContexts;
- }
}
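StarTreeUtils.isFitForStarTree now takes arrays instead of sets. As far as this diff shows, the aggregation check only iterates over the pairs. The visible parts of the check require each requested function-column pair to be present in the star-tree metadata and each filter column to be a star-tree dimension. Below is a minimal sketch of that kind of check, with several assumptions. Group-by columns must also be dimensions, which goes beyond what this diff shows. Pairs are modeled as hypothetical strings such as "sum__m1". The filter tree is flattened into a plain column list.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

final class StarTreeFitCheck {
  // Returns true if the star-tree can answer the query:
  // - every requested function-column pair is pre-aggregated in the star-tree, and
  // - every group-by column and every filter column is a star-tree dimension.
  // A null group-by or filter array means "no group-by" or "no filter".
  static boolean isFitForStarTree(Set<String> supportedPairs, Set<String> dimensions,
      String[] requestedPairs, String[] groupByColumns, String[] filterColumns) {
    for (String pair : requestedPairs) {
      if (!supportedPairs.contains(pair)) {
        return false;
      }
    }
    if (groupByColumns != null && !dimensions.containsAll(Arrays.asList(groupByColumns))) {
      return false;
    }
    return filterColumns == null || dimensions.containsAll(Arrays.asList(filterColumns));
  }

  public static void main(String[] args) {
    Set<String> pairs = new HashSet<>(Arrays.asList("sum__m1", "max__m2"));
    Set<String> dims = new HashSet<>(Arrays.asList("d1", "d2"));
    System.out.println(isFitForStarTree(pairs, dims, new String[]{"sum__m1"}, new String[]{"d1"}, null)); // true
    System.out.println(isFitForStarTree(pairs, dims, new String[]{"avg__m1"}, null, null)); // false
  }
}
```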
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
index daf9b64..110efb2 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeAggregationExecutor.java
@@ -18,18 +18,10 @@
*/
package org.apache.pinot.core.startree.executor;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.blocks.TransformBlock;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
-import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.startree.StarTreeUtils;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
@@ -37,50 +29,31 @@ import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
* The <code>StarTreeAggregationExecutor</code> class is the aggregation executor for star-tree index.
* <ul>
* <li>The column in function context is function-column pair</li>
- * <li>No UDF in aggregation</li>
+ * <li>No transform function in aggregation</li>
* <li>For <code>COUNT</code> aggregation function, we need to aggregate on the pre-aggregated column</li>
* </ul>
*/
public class StarTreeAggregationExecutor extends DefaultAggregationExecutor {
- // StarTree converts column names from min(col) to min__col, this is to store the original mapping.
- private final String[][] _functionArgs;
+ private final AggregationFunctionColumnPair[] _aggregationFunctionColumnPairs;
- public StarTreeAggregationExecutor(AggregationFunctionContext[] functionContexts) {
- super(StarTreeUtils.createStarTreeFunctionContexts(functionContexts));
+ public StarTreeAggregationExecutor(AggregationFunction[] aggregationFunctions) {
+ super(aggregationFunctions);
- _functionArgs = new String[functionContexts.length][];
- for (int i = 0; i < functionContexts.length; i++) {
- List<String> expressions = functionContexts[i].getExpressions();
- _functionArgs[i] = new String[expressions.size()];
-
- for (int j = 0; j < expressions.size(); j++) {
- _functionArgs[i][j] = expressions.get(j);
- }
+ int numAggregationFunctions = aggregationFunctions.length;
+ _aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ _aggregationFunctionColumnPairs[i] =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
}
}
@Override
public void aggregate(TransformBlock transformBlock) {
+ int numAggregationFunctions = _aggregationFunctions.length;
int length = transformBlock.getNumDocs();
- for (int i = 0; i < _numFunctions; i++) {
- AggregationFunction function = _functions[i];
- AggregationResultHolder resultHolder = _resultHolders[i];
-
- AggregationFunctionType functionType = function.getType();
- if ((functionType == AggregationFunctionType.COUNT)) {
- BlockValSet blockValueSet =
- transformBlock.getBlockValueSet(AggregationFunctionColumnPair.COUNT_STAR_COLUMN_NAME);
- function.aggregate(length, resultHolder, Collections.singletonMap(_functionArgs[i][0], blockValueSet));
- } else {
-
- Map<String, BlockValSet> blockValSetMap = new HashMap<>();
- for (int j = 0; j < _functionArgs[i].length; j++) {
- blockValSetMap.put(_functionArgs[i][j], transformBlock.getBlockValueSet(
- AggregationFunctionColumnPair.toColumnName(functionType, _expressions[i][j].getValue())));
- }
-
- function.aggregate(length, resultHolder, blockValSetMap);
- }
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ _aggregationFunctions[i].aggregate(length, _aggregationResultHolders[i],
+ AggregationFunctionUtils.getBlockValSetMap(_aggregationFunctionColumnPairs[i], transformBlock));
}
}
}
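The rewritten StarTreeAggregationExecutor computes one AggregationFunctionColumnPair per function in its constructor. For every block it then calls AggregationFunctionUtils.getBlockValSetMap to get a map keyed by the function's input expression. Per the commit message, those keys match AggregationFunction.getInputExpressions(), with one exception: COUNT on star-tree gets a single entry for column "*". Below is a minimal sketch of that key-to-column mapping, with several assumptions. Column names follow the "min(col) -> min__col" convention from the removed comment. The COUNT column is hypothetically named "count__*". Plain strings stand in for TransformExpressionTree keys and BlockValSet values.

```java
import java.util.Collections;
import java.util.Locale;
import java.util.Map;

final class StarTreeBlockValSetKeys {
  // Hypothetical name of the pre-aggregated COUNT column; not taken from Pinot's source.
  static final String COUNT_STAR_COLUMN_NAME = "count__*";

  // Follows the "min(col) -> min__col" convention described in the removed comment.
  static String toColumnName(String functionType, String column) {
    return functionType.toLowerCase(Locale.ROOT) + "__" + column;
  }

  // Maps the aggregation's input expression (the key the function looks up) to the
  // pre-aggregated star-tree column that should be read for it.
  // COUNT is the exception: a single entry keyed by "*".
  static Map<String, String> keysToColumns(String functionType, String inputExpression) {
    if (functionType.equalsIgnoreCase("count")) {
      return Collections.singletonMap("*", COUNT_STAR_COLUMN_NAME);
    }
    return Collections.singletonMap(inputExpression, toColumnName(functionType, inputExpression));
  }

  public static void main(String[] args) {
    System.out.println(keysToColumns("MIN", "col")); // {col=min__col}
    System.out.println(keysToColumns("COUNT", "*")); // {*=count__*}
  }
}
```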
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
index cd0518a..fef887c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/executor/StarTreeGroupByExecutor.java
@@ -18,22 +18,15 @@
*/
package org.apache.pinot.core.startree.executor;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import javax.annotation.Nonnull;
-import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.GroupBy;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.operator.blocks.TransformBlock;
import org.apache.pinot.core.operator.transform.TransformOperator;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.query.aggregation.groupby.DefaultGroupByExecutor;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
-import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
@@ -41,59 +34,36 @@ import org.apache.pinot.core.startree.v2.AggregationFunctionColumnPair;
* The <code>StarTreeGroupByExecutor</code> class is the group-by executor for star-tree index.
* <ul>
* <li>The column in function context is function-column pair</li>
- * <li>No UDF in aggregation</li>
+ * <li>No transform function in aggregation</li>
* <li>For <code>COUNT</code> aggregation function, we need to aggregate on the pre-aggregated column</li>
* </ul>
*/
public class StarTreeGroupByExecutor extends DefaultGroupByExecutor {
+ private final AggregationFunctionColumnPair[] _aggregationFunctionColumnPairs;
- // StarTree converts column names from min(col) to min__col, this is to store the original mapping.
- private final String[][] _functionArgs;
+ public StarTreeGroupByExecutor(AggregationFunction[] aggregationFunctions,
+ TransformExpressionTree[] groupByExpressions, int maxInitialResultHolderCapacity, int numGroupsLimit,
+ TransformOperator transformOperator) {
+ super(aggregationFunctions, groupByExpressions, maxInitialResultHolderCapacity, numGroupsLimit, transformOperator);
- public StarTreeGroupByExecutor(@Nonnull AggregationFunctionContext[] functionContexts, @Nonnull GroupBy groupBy,
- int maxInitialResultHolderCapacity, int numGroupsLimit, @Nonnull TransformOperator transformOperator) {
- super(StarTreeUtils.createStarTreeFunctionContexts(functionContexts), groupBy, maxInitialResultHolderCapacity,
- numGroupsLimit, transformOperator);
-
- _functionArgs = new String[functionContexts.length][];
- for (int i = 0; i < functionContexts.length; i++) {
- List<String> expressions = functionContexts[i].getExpressions();
- _functionArgs[i] = new String[expressions.size()];
-
- for (int j = 0; j < expressions.size(); j++) {
- _functionArgs[i][j] = expressions.get(j);
- }
+ int numAggregationFunctions = aggregationFunctions.length;
+ _aggregationFunctionColumnPairs = new AggregationFunctionColumnPair[numAggregationFunctions];
+ for (int i = 0; i < numAggregationFunctions; i++) {
+ _aggregationFunctionColumnPairs[i] =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunctions[i]);
}
}
@Override
- protected void aggregate(@Nonnull TransformBlock transformBlock, int length, int functionIndex) {
- AggregationFunction function = _functions[functionIndex];
- GroupByResultHolder resultHolder = _resultHolders[functionIndex];
-
- BlockValSet blockValueSet;
- Map<String, BlockValSet> blockValSetMap;
-
- AggregationFunctionType functionType = function.getType();
- if (functionType == AggregationFunctionType.COUNT) {
- blockValueSet = transformBlock.getBlockValueSet(AggregationFunctionColumnPair.COUNT_STAR_COLUMN_NAME);
- function.getInputExpressions();
- blockValSetMap = Collections.singletonMap(_functionArgs[functionIndex][0], blockValueSet);
- } else {
-
- blockValSetMap = new HashMap<>();
- for (int i = 0; i < _functionArgs[functionIndex].length; i++) {
- TransformExpressionTree aggregationExpression = _aggregationExpressions[functionIndex][i];
- blockValueSet = transformBlock.getBlockValueSet(
- AggregationFunctionColumnPair.toColumnName(functionType, aggregationExpression.getValue()));
- blockValSetMap.put(_functionArgs[functionIndex][i], blockValueSet);
- }
- }
-
+ protected void aggregate(TransformBlock transformBlock, int length, int functionIndex) {
+ AggregationFunction aggregationFunction = _aggregationFunctions[functionIndex];
+ GroupByResultHolder groupByResultHolder = _groupByResultHolders[functionIndex];
+ Map<TransformExpressionTree, BlockValSet> blockValSetMap =
+ AggregationFunctionUtils.getBlockValSetMap(_aggregationFunctionColumnPairs[functionIndex], transformBlock);
if (_hasMVGroupByExpression) {
- function.aggregateGroupByMV(length, _mvGroupKeys, resultHolder, blockValSetMap);
+ aggregationFunction.aggregateGroupByMV(length, _mvGroupKeys, groupByResultHolder, blockValSetMap);
} else {
- function.aggregateGroupBySV(length, _svGroupKeys, resultHolder, blockValSetMap);
+ aggregationFunction.aggregateGroupBySV(length, _svGroupKeys, groupByResultHolder, blockValSetMap);
}
}
-}
\ No newline at end of file
+}
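StarTreeGroupByExecutor builds the same kind of blockValSetMap once per function, then dispatches on _hasMVGroupByExpression. With only single-valued group-by expressions, each document maps to exactly one group key (_svGroupKeys). A multi-valued group-by expression can map one document to several group keys (_mvGroupKeys). Below is a minimal sketch of the difference for a SUM-like aggregation. It assumes group keys are dense int ids and uses a plain double[] as a hypothetical stand-in for GroupByResultHolder.

```java
import java.util.Arrays;

final class GroupBySumSketch {
  // Single-valued group-by: document i contributes its value to exactly one group, groupKeys[i].
  static void aggregateGroupBySV(int length, int[] groupKeys, double[] values, double[] results) {
    for (int i = 0; i < length; i++) {
      results[groupKeys[i]] += values[i];
    }
  }

  // Multi-valued group-by: document i contributes its value to every group in groupKeys[i].
  static void aggregateGroupByMV(int length, int[][] groupKeys, double[] values, double[] results) {
    for (int i = 0; i < length; i++) {
      for (int groupKey : groupKeys[i]) {
        results[groupKey] += values[i];
      }
    }
  }

  public static void main(String[] args) {
    double[] sv = new double[2];
    aggregateGroupBySV(3, new int[]{0, 1, 0}, new double[]{1, 2, 3}, sv);
    System.out.println(Arrays.toString(sv)); // [4.0, 2.0]
  }
}
```

Resolving the map once per function before the SV/MV branch, as the new code does, keeps both branches operating on identical inputs.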
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
index 10c2276..03938d1 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
@@ -18,6 +18,7 @@
*/
package org.apache.pinot.core.startree.plan;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
@@ -40,8 +41,8 @@ public class StarTreeTransformPlanNode implements PlanNode {
private final StarTreeProjectionPlanNode _starTreeProjectionPlanNode;
public StarTreeTransformPlanNode(StarTreeV2 starTreeV2,
- Set<AggregationFunctionColumnPair> aggregationFunctionColumnPairs,
- @Nullable Set<TransformExpressionTree> groupByExpressions, @Nullable FilterQueryTree rootFilterNode,
+ AggregationFunctionColumnPair[] aggregationFunctionColumnPairs,
+ @Nullable TransformExpressionTree[] groupByExpressions, @Nullable FilterQueryTree rootFilterNode,
@Nullable Map<String, String> debugOptions) {
Set<String> projectionColumns = new HashSet<>();
for (AggregationFunctionColumnPair aggregationFunctionColumnPair : aggregationFunctionColumnPairs) {
@@ -49,7 +50,7 @@ public class StarTreeTransformPlanNode implements PlanNode {
}
Set<String> groupByColumns;
if (groupByExpressions != null) {
- _groupByExpressions = groupByExpressions;
+ _groupByExpressions = new HashSet<>(Arrays.asList(groupByExpressions));
groupByColumns = new HashSet<>();
for (TransformExpressionTree groupByExpression : groupByExpressions) {
groupByExpression.getColumns(groupByColumns);
@@ -65,6 +66,9 @@ public class StarTreeTransformPlanNode implements PlanNode {
@Override
public TransformOperator run() {
+ // NOTE: Here we do not put aggregation expressions into TransformOperator based on the following assumptions:
+ // - They are all columns (not functions or constants), where no transform is required
+ // - We never call TransformOperator.getResultMetadata() or TransformOperator.getDictionary() on them
return new TransformOperator(_starTreeProjectionPlanNode.run(), _groupByExpressions);
}
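StarTreeTransformPlanNode now receives the group-by expressions as an array and copies them into a HashSet. The commit message also notes that keying blockValSetMap by TransformExpressionTree avoids redundant string conversion and gives cheaper hashCode() and equals(). Both rely on value-based equality of the expression tree. Below is a minimal sketch under a few assumptions. Expr is a hypothetical immutable class with a cached hash, not Pinot's TransformExpressionTree. The helper gathers projection columns from function-column pair names and group-by expressions, assuming the loop body elided at the top of the constructor adds each pair's column name.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

final class ProjectionColumns {
  // Gathers every column the projection must read: the function-column pair columns
  // plus all columns referenced by the group-by expressions (if any).
  static Set<String> collect(String[] pairColumnNames, Expr[] groupByExpressions) {
    Set<String> columns = new HashSet<>(Arrays.asList(pairColumnNames));
    if (groupByExpressions != null) {
      for (Expr expression : groupByExpressions) {
        expression.getColumns(columns);
      }
    }
    return columns;
  }

  public static void main(String[] args) {
    Expr[] groupBy = {Expr.column("d1"), Expr.function("add", Expr.column("d2"), Expr.column("d3"))};
    System.out.println(collect(new String[]{"sum__m1"}, groupBy).size()); // 4
  }
}

// Hypothetical immutable expression tree: a column identifier or a function call with arguments.
final class Expr {
  private final boolean isFunction;
  private final String value; // column name or function name
  private final List<Expr> children;
  private final int hash; // cached: safe because the tree never changes after construction

  private Expr(boolean isFunction, String value, List<Expr> children) {
    this.isFunction = isFunction;
    this.value = value;
    this.children = Collections.unmodifiableList(new ArrayList<>(children));
    this.hash = Objects.hash(isFunction, value, this.children);
  }

  static Expr column(String name) {
    return new Expr(false, name, Collections.emptyList());
  }

  static Expr function(String name, Expr... arguments) {
    return new Expr(true, name, Arrays.asList(arguments));
  }

  // Adds every column referenced by this expression to the given set.
  void getColumns(Set<String> columns) {
    if (!isFunction) {
      columns.add(value);
      return;
    }
    for (Expr child : children) {
      child.getColumns(columns);
    }
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof Expr)) {
      return false;
    }
    Expr other = (Expr) o;
    // Comparing the cached hashes first rejects most unequal trees without a full traversal.
    return hash == other.hash && isFunction == other.isFunction && value.equals(other.value)
        && children.equals(other.children);
  }

  @Override
  public int hashCode() {
    return hash;
  }
}
```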
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
index 6c90739..960791a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/table/IndexedTableTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.core.data.table;
import com.google.common.collect.Lists;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
@@ -29,10 +30,12 @@ import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
@@ -44,31 +47,18 @@ import org.testng.annotations.Test;
public class IndexedTableTest {
@Test
- public void testConcurrentIndexedTable() throws InterruptedException, TimeoutException, ExecutionException {
-
+ public void testConcurrentIndexedTable()
+ throws InterruptedException, TimeoutException, ExecutionException {
DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)"},
- new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
- ColumnDataType.DOUBLE});
-
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>(1);
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>(1);
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
- List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
+ new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+ AggregationFunction[] aggregationFunctions =
+ new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+ SelectionSort selectionSort = new SelectionSort();
+ selectionSort.setColumn("sum(m1)");
+ selectionSort.setIsAsc(true);
+ List<SelectionSort> orderBy = Collections.singletonList(selectionSort);
- SelectionSort sel = new SelectionSort();
- sel.setColumn("sum(m1)");
- sel.setIsAsc(true);
- List<SelectionSort> orderBy = Lists.newArrayList(sel);
-
- IndexedTable indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
+ IndexedTable indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
// 3 threads upsert together
// a inserted 6 times (60), b inserted 5 times (50), d inserted 2 times (20)
@@ -81,7 +71,8 @@ public class IndexedTableTest {
Callable<Void> c1 = () -> {
indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
- indexedTable.upsert(getKey(new Object[]{"c", 3, 30d}), getRecord(new Object[]{"c", 3, 30d, 10000d, 300d})); // eviction candidate
+ indexedTable.upsert(getKey(new Object[]{"c", 3, 30d}),
+ getRecord(new Object[]{"c", 3, 30d, 10000d, 300d})); // eviction candidate
indexedTable.upsert(getKey(new Object[]{"d", 4, 40d}), getRecord(new Object[]{"d", 4, 40d, 10d, 400d}));
indexedTable.upsert(getKey(new Object[]{"d", 4, 40d}), getRecord(new Object[]{"d", 4, 40d, 10d, 400d}));
indexedTable.upsert(getKey(new Object[]{"e", 5, 50d}), getRecord(new Object[]{"e", 5, 50d, 10d, 500d}));
@@ -90,27 +81,29 @@ public class IndexedTableTest {
Callable<Void> c2 = () -> {
indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
- indexedTable.upsert(getKey(new Object[]{"f", 6, 60d}), getRecord(new Object[]{"f", 6, 60d,20000d, 600d})); // eviction candidate
- indexedTable.upsert(getKey(new Object[]{"g", 7, 70d}), getRecord(new Object[]{"g", 7, 70d,10d, 700d}));
- indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
- indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
- indexedTable.upsert(getKey(new Object[]{"h", 8, 80d}), getRecord(new Object[]{"h", 8, 80d,10d, 800d}));
- indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
- indexedTable.upsert(getKey(new Object[]{"i", 9, 90d}), getRecord(new Object[]{"i", 9, 90d,500d, 900d}));
+ indexedTable.upsert(getKey(new Object[]{"f", 6, 60d}),
+ getRecord(new Object[]{"f", 6, 60d, 20000d, 600d})); // eviction candidate
+ indexedTable.upsert(getKey(new Object[]{"g", 7, 70d}), getRecord(new Object[]{"g", 7, 70d, 10d, 700d}));
+ indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+ indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+ indexedTable.upsert(getKey(new Object[]{"h", 8, 80d}), getRecord(new Object[]{"h", 8, 80d, 10d, 800d}));
+ indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+ indexedTable.upsert(getKey(new Object[]{"i", 9, 90d}), getRecord(new Object[]{"i", 9, 90d, 500d, 900d}));
return null;
};
Callable<Void> c3 = () -> {
- indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
- indexedTable.upsert(getKey(new Object[]{"j", 10, 100d}), getRecord(new Object[]{"j", 10, 100d,10d, 1000d}));
- indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
- indexedTable.upsert(getKey(new Object[]{"k", 11, 110d}), getRecord(new Object[]{"k", 11, 110d,10d, 1100d}));
- indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d}));
- indexedTable.upsert(getKey(new Object[]{"l", 12, 120d}), getRecord(new Object[]{"l", 12, 120d,10d, 1200d}));
- indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d,10d, 100d})); // trimming candidate
- indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d,10d, 200d}));
- indexedTable.upsert(getKey(new Object[]{"m", 13, 130d}), getRecord(new Object[]{"m", 13, 130d,10d, 1300d}));
- indexedTable.upsert(getKey(new Object[]{"n", 14, 140d}), getRecord(new Object[]{"n", 14, 140d,10d, 1400d}));
+ indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+ indexedTable.upsert(getKey(new Object[]{"j", 10, 100d}), getRecord(new Object[]{"j", 10, 100d, 10d, 1000d}));
+ indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+ indexedTable.upsert(getKey(new Object[]{"k", 11, 110d}), getRecord(new Object[]{"k", 11, 110d, 10d, 1100d}));
+ indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}), getRecord(new Object[]{"a", 1, 10d, 10d, 100d}));
+ indexedTable.upsert(getKey(new Object[]{"l", 12, 120d}), getRecord(new Object[]{"l", 12, 120d, 10d, 1200d}));
+ indexedTable.upsert(getKey(new Object[]{"a", 1, 10d}),
+ getRecord(new Object[]{"a", 1, 10d, 10d, 100d})); // trimming candidate
+ indexedTable.upsert(getKey(new Object[]{"b", 2, 20d}), getRecord(new Object[]{"b", 2, 20d, 10d, 200d}));
+ indexedTable.upsert(getKey(new Object[]{"m", 13, 130d}), getRecord(new Object[]{"m", 13, 130d, 10d, 1300d}));
+ indexedTable.upsert(getKey(new Object[]{"n", 14, 140d}), getRecord(new Object[]{"n", 14, 140d, 10d, 1400d}));
return null;
};
@@ -122,7 +115,6 @@ public class IndexedTableTest {
indexedTable.finish(false);
Assert.assertEquals(indexedTable.size(), 5);
checkEvicted(indexedTable, "c", "f");
-
} finally {
executorService.shutdown();
}
@@ -130,27 +122,15 @@ public class IndexedTableTest {
@Test(dataProvider = "initDataProvider")
public void testNonConcurrentIndexedTable(List<SelectionSort> orderBy, List<String> survivors) {
-
DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "d4", "sum(m1)", "max(m2)"},
new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
-
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>(1);
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>(1);
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
- List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
+ AggregationFunction[] aggregationFunctions =
+ new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
// Test SimpleIndexedTable
- IndexedTable simpleIndexedTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
+ IndexedTable simpleIndexedTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
// merge table
- IndexedTable mergeTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 10);
+ IndexedTable mergeTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 10);
testNonConcurrent(simpleIndexedTable, mergeTable);
// finish
@@ -158,8 +138,8 @@ public class IndexedTableTest {
checkSurvivors(simpleIndexedTable, survivors);
// Test ConcurrentIndexedTable
- IndexedTable concurrentIndexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, orderBy, 5);
- mergeTable = new SimpleIndexedTable(dataSchema, aggregationInfos, orderBy, 10);
+ IndexedTable concurrentIndexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, orderBy, 5);
+ mergeTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, orderBy, 10);
testNonConcurrent(concurrentIndexedTable, mergeTable);
// finish
@@ -300,6 +280,7 @@ public class IndexedTableTest {
private Key getKey(Object[] keys) {
return new Key(keys);
}
+
private Record getRecord(Object[] columns) {
return new Record(columns);
}
@@ -307,26 +288,14 @@ public class IndexedTableTest {
@Test
public void testNoMoreNewRecords() {
DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)"},
- new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE,
- ColumnDataType.DOUBLE});
-
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>(1);
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>(1);
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
- List<AggregationInfo> aggregationInfos = Lists.newArrayList(agg1, agg2);
-
- IndexedTable indexedTable = new SimpleIndexedTable(dataSchema, aggregationInfos, null, 5);
+ new ColumnDataType[]{ColumnDataType.STRING, ColumnDataType.INT, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});
+ AggregationFunction[] aggregationFunctions =
+ new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+
+ IndexedTable indexedTable = new SimpleIndexedTable(dataSchema, aggregationFunctions, null, 5);
testNoMoreNewRecordsInTable(indexedTable);
- indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationInfos, null, 5);
+ indexedTable = new ConcurrentIndexedTable(dataSchema, aggregationFunctions, null, 5);
testNoMoreNewRecordsInTable(indexedTable);
}
@@ -355,6 +324,5 @@ public class IndexedTableTest {
indexedTable.finish(false);
checkEvicted(indexedTable, "f", "g");
-
}
}
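The updated IndexedTableTest builds ConcurrentIndexedTable and SimpleIndexedTable from AggregationFunction instances (SumAggregationFunction("m1"), MaxAggregationFunction("m2")) instead of hand-assembled AggregationInfo objects. It checks that, with an ascending order-by on sum(m1) and a capacity of 5, the groups with the largest sums ("c" and "f") are evicted. The merge-then-trim idea behind those assertions is sketched below. TinyIndexedTable is hypothetical and not a Pinot class: it is single-threaded, merges SUM and MAX per key, trims by ascending SUM, and does not reproduce Pinot's resize thresholds.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class TinyIndexedTable {
  // group key -> {sum, max}
  private final Map<String, double[]> records = new HashMap<>();

  // Merges a partial record into the table: SUM values are added, MAX keeps the larger value.
  void upsert(String key, double sum, double max) {
    records.merge(key, new double[]{sum, max},
        (existing, incoming) -> new double[]{existing[0] + incoming[0], Math.max(existing[1], incoming[1])});
  }

  // Keeps the `capacity` records with the smallest SUM (order by sum ascending); ties are broken by key.
  void trimTo(int capacity) {
    if (records.size() <= capacity) {
      return;
    }
    List<String> keys = new ArrayList<>(records.keySet());
    keys.sort(Comparator.comparingDouble((String k) -> records.get(k)[0])
        .thenComparing(Comparator.<String>naturalOrder()));
    for (String evicted : keys.subList(capacity, keys.size())) {
      records.remove(evicted);
    }
  }

  int size() {
    return records.size();
  }

  boolean contains(String key) {
    return records.containsKey(key);
  }

  double sum(String key) {
    return records.get(key)[0];
  }

  double max(String key) {
    return records.get(key)[1];
  }

  public static void main(String[] args) {
    TinyIndexedTable table = new TinyIndexedTable();
    table.upsert("a", 10, 100);
    table.upsert("a", 10, 50);
    table.upsert("b", 20, 200);
    table.upsert("c", 10000, 300); // largest sum: eviction candidate
    table.upsert("d", 5, 400);
    table.trimTo(3);
    System.out.println(table.contains("c")); // false
  }
}
```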
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
index 15c1e3f..61607c7 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/data/table/TableResizerTest.java
@@ -20,18 +20,20 @@ package org.apache.pinot.core.data.table;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
-import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
+import org.apache.pinot.core.query.aggregation.function.AvgAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.DistinctCountAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
@@ -42,14 +44,8 @@ import org.testng.annotations.Test;
* Tests the functionality of {@link @TableResizer}
*/
public class TableResizerTest {
-
- private DataSchema dataSchema;
- private List<AggregationInfo> aggregationInfos;
- private List<SelectionSort> selectionSort;
- private SelectionSort sel1;
- private SelectionSort sel2;
- private SelectionSort sel3;
- private TableResizer _tableResizer;
+ private DataSchema _dataSchema;
+ private AggregationFunction[] _aggregationFunctions;
private int trimToSize = 3;
private Map<Key, Record> _recordsMap;
@@ -57,37 +53,11 @@ public class TableResizerTest {
private List<Key> _keys;
@BeforeClass
- public void beforeClass() {
- dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)", "distinctcount(m3)", "avg(m4)"},
+ public void setUp() {
+ _dataSchema = new DataSchema(new String[]{"d1", "d2", "d3", "sum(m1)", "max(m2)", "distinctcount(m3)", "avg(m4)"},
new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.OBJECT, DataSchema.ColumnDataType.OBJECT});
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>(1);
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>(1);
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
-
- AggregationInfo agg3 = new AggregationInfo();
- List<String> args3 = new ArrayList<>(1);
- args3.add("m3");
- agg3.setExpressions(args3);
- agg3.setAggregationType("distinctcount");
-
- AggregationInfo agg4 = new AggregationInfo();
- List<String> args4 = new ArrayList<>(1);
- args4.add("m4");
- agg4.setExpressions(args4);
- agg4.setAggregationType("avg");
- aggregationInfos = Lists.newArrayList(agg1, agg2, agg3, agg4);
-
- sel1 = new SelectionSort();
- sel2 = new SelectionSort();
- sel3 = new SelectionSort();
+ _aggregationFunctions = new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction(
+ "m2"), new DistinctCountAggregationFunction("m3"), new AvgAggregationFunction("m4")};
IntOpenHashSet i1 = new IntOpenHashSet();
i1.add(1);
@@ -134,95 +104,93 @@ public class TableResizerTest {
@Test
public void testResizeRecordsMap() {
Map<Key, Record> recordsMap;
+ SelectionSort selectionSort1 = new SelectionSort();
+ SelectionSort selectionSort2 = new SelectionSort();
+ SelectionSort selectionSort3 = new SelectionSort();
// Test resize algorithm with numRecordsToEvict < trimToSize.
// TotalRecords=5; trimToSize=3; numRecordsToEvict=2
// d1 asc
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ TableResizer tableResizer =
+ new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b, c
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
// d1 desc
- sel1.setColumn("d1");
- sel1.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(2))); // c, c, c
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
Assert.assertTrue(recordsMap.containsKey(_keys.get(4)));
// d1 asc, d3 desc (tie breaking with 2nd comparator)
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- sel2.setColumn("d3");
- sel2.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ selectionSort2.setColumn("d3");
+ selectionSort2.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // 10, 10, 300
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
Assert.assertTrue(recordsMap.containsKey(_keys.get(4)));
// d2 asc
- sel1.setColumn("d2");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d2");
+ selectionSort1.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // 10, 10, 50
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
// d1 asc, sum(m1) desc, max(m2) desc
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- sel2.setColumn("sum(m1)");
- sel2.setIsAsc(false);
- sel3.setColumn("max(m2)");
- sel3.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1, sel2, sel3);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ selectionSort2.setColumn("sum(m1)");
+ selectionSort2.setIsAsc(false);
+ selectionSort3.setColumn("max(m2)");
+ selectionSort3.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+ Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b, (c (30, 300))
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
Assert.assertTrue(recordsMap.containsKey(_keys.get(2)));
// object type avg(m4) asc
- sel1.setColumn("avg(m4)");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("avg(m4)");
+ selectionSort1.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 2, 3, 3.33,
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
// non-comparable intermediate result
- sel1.setColumn("distinctcount(m3)");
- sel1.setIsAsc(false);
- sel2.setColumn("d1");
- sel2.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("distinctcount(m3)");
+ selectionSort1.setIsAsc(false);
+ selectionSort2.setColumn("d1");
+ selectionSort2.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 6, 5, 4 (b)
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
@@ -233,36 +201,33 @@ public class TableResizerTest {
trimToSize = 2;
// d1 asc
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(0))); // a, b
Assert.assertTrue(recordsMap.containsKey(_keys.get(1)));
// object type avg(m4) asc
- sel1.setColumn("avg(m4)");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("avg(m4)");
+ selectionSort1.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 2, 3, 3.33,
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
// non-comparable intermediate result
- sel1.setColumn("distinctcount(m3)");
- sel1.setIsAsc(false);
- sel2.setColumn("d1");
- sel2.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("distinctcount(m3)");
+ selectionSort1.setIsAsc(false);
+ selectionSort2.setColumn("d1");
+ selectionSort2.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- _tableResizer.resizeRecordsMap(recordsMap, trimToSize);
+ tableResizer.resizeRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(recordsMap.size(), trimToSize);
Assert.assertTrue(recordsMap.containsKey(_keys.get(4))); // 6, 5, 4 (b)
Assert.assertTrue(recordsMap.containsKey(_keys.get(3)));
@@ -279,14 +244,17 @@ public class TableResizerTest {
List<Record> sortedRecords;
int[] order;
Map<Key, Record> recordsMap;
+ SelectionSort selectionSort1 = new SelectionSort();
+ SelectionSort selectionSort2 = new SelectionSort();
+ SelectionSort selectionSort3 = new SelectionSort();
// d1 asc
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ TableResizer tableResizer =
+ new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(sortedRecords.size(), trimToSize);
order = new int[]{0, 1};
for (int i = 0; i < order.length; i++) {
@@ -295,7 +263,7 @@ public class TableResizerTest {
// d1 asc - trim to 1
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
Assert.assertEquals(sortedRecords.size(), 1);
order = new int[]{0};
for (int i = 0; i < order.length; i++) {
@@ -303,14 +271,13 @@ public class TableResizerTest {
}
// d1 asc, d3 desc (tie breaking with 2nd comparator)
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- sel2.setColumn("d3");
- sel2.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ selectionSort2.setColumn("d3");
+ selectionSort2.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(sortedRecords.size(), trimToSize);
order = new int[]{0, 1, 4};
for (int i = 0; i < order.length; i++) {
@@ -319,7 +286,7 @@ public class TableResizerTest {
// d1 asc, d3 desc (tie breaking with 2nd comparator) - trim 1
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
Assert.assertEquals(sortedRecords.size(), 1);
order = new int[]{0};
for (int i = 0; i < order.length; i++) {
@@ -327,16 +294,16 @@ public class TableResizerTest {
}
// d1 asc, sum(m1) desc, max(m2) desc
- sel1.setColumn("d1");
- sel1.setIsAsc(true);
- sel2.setColumn("sum(m1)");
- sel2.setIsAsc(false);
- sel3.setColumn("max(m2)");
- sel3.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1, sel2, sel3);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort1.setIsAsc(true);
+ selectionSort2.setColumn("sum(m1)");
+ selectionSort2.setIsAsc(false);
+ selectionSort3.setColumn("max(m2)");
+ selectionSort3.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+ Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, trimToSize);
Assert.assertEquals(sortedRecords.size(), trimToSize);
order = new int[]{0, 1, 2};
for (int i = 0; i < order.length; i++) {
@@ -345,7 +312,7 @@ public class TableResizerTest {
// trim 1
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 1);
Assert.assertEquals(sortedRecords.size(), 1);
order = new int[]{0};
for (int i = 0; i < order.length; i++) {
@@ -353,14 +320,13 @@ public class TableResizerTest {
}
// object type avg(m4) asc
- sel1.setColumn("avg(m4)");
- sel1.setIsAsc(true);
- sel2.setColumn("d1");
- sel2.setIsAsc(true);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("avg(m4)");
+ selectionSort1.setIsAsc(true);
+ selectionSort2.setColumn("d1");
+ selectionSort2.setIsAsc(true);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, 10); // high trim to size
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, 10); // high trim to size
Assert.assertEquals(sortedRecords.size(), recordsMap.size());
order = new int[]{4, 3, 1, 0, 2};
for (int i = 0; i < order.length; i++) {
@@ -368,14 +334,13 @@ public class TableResizerTest {
}
// non-comparable intermediate result
- sel1.setColumn("distinctcount(m3)");
- sel1.setIsAsc(false);
- sel2.setColumn("avg(m4)");
- sel2.setIsAsc(false);
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("distinctcount(m3)");
+ selectionSort1.setIsAsc(false);
+ selectionSort2.setColumn("avg(m4)");
+ selectionSort2.setIsAsc(false);
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
recordsMap = new HashMap<>(_recordsMap);
- sortedRecords = _tableResizer.resizeAndSortRecordsMap(recordsMap, recordsMap.size()); // equal trim to size
+ sortedRecords = tableResizer.resizeAndSortRecordsMap(recordsMap, recordsMap.size()); // equal trim to size
Assert.assertEquals(sortedRecords.size(), recordsMap.size());
order = new int[]{4, 3, 2, 1, 0};
for (int i = 0; i < order.length; i++) {
@@ -388,45 +353,43 @@ public class TableResizerTest {
*/
@Test
public void testIntermediateRecord() {
+ SelectionSort selectionSort1 = new SelectionSort();
+ SelectionSort selectionSort2 = new SelectionSort();
+ SelectionSort selectionSort3 = new SelectionSort();
// d2
- sel1.setColumn("d2");
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d2");
+ TableResizer tableResizer =
+ new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
Key key = entry.getKey();
Record record = entry.getValue();
- TableResizer.IntermediateRecord intermediateRecord =
- _tableResizer.getIntermediateRecord(key, record);
+ TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
Assert.assertEquals(intermediateRecord._key, key);
Assert.assertEquals(intermediateRecord._values.length, 1);
Assert.assertEquals(intermediateRecord._values[0], record.getValues()[1]);
}
// sum(m1)
- sel1.setColumn("sum(m1)");
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("sum(m1)");
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
Key key = entry.getKey();
Record record = entry.getValue();
- TableResizer.IntermediateRecord intermediateRecord =
- _tableResizer.getIntermediateRecord(key, record);
+ TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
Assert.assertEquals(intermediateRecord._key, key);
Assert.assertEquals(intermediateRecord._values.length, 1);
Assert.assertEquals(intermediateRecord._values[0], record.getValues()[3]);
}
// d1, max(m2)
- sel1.setColumn("d1");
- sel2.setColumn("max(m2)");
- selectionSort = Lists.newArrayList(sel1, sel2);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d1");
+ selectionSort2.setColumn("max(m2)");
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Arrays.asList(selectionSort1, selectionSort2));
for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
Key key = entry.getKey();
Record record = entry.getValue();
- TableResizer.IntermediateRecord intermediateRecord =
- _tableResizer.getIntermediateRecord(key, record);
+ TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
Assert.assertEquals(intermediateRecord._key, key);
Assert.assertEquals(intermediateRecord._values.length, 2);
Assert.assertEquals(intermediateRecord._values[0], record.getValues()[0]);
@@ -434,16 +397,15 @@ public class TableResizerTest {
}
// d2, sum(m1), d3
- sel1.setColumn("d2");
- sel2.setColumn("sum(m1)");
- sel3.setColumn("d3");
- selectionSort = Lists.newArrayList(sel1, sel2, sel3);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
+ selectionSort1.setColumn("d2");
+ selectionSort2.setColumn("sum(m1)");
+ selectionSort3.setColumn("d3");
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions,
+ Arrays.asList(selectionSort1, selectionSort2, selectionSort3));
for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
Key key = entry.getKey();
Record record = entry.getValue();
- TableResizer.IntermediateRecord intermediateRecord =
- _tableResizer.getIntermediateRecord(key, record);
+ TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
Assert.assertEquals(intermediateRecord._key, key);
Assert.assertEquals(intermediateRecord._values.length, 3);
Assert.assertEquals(intermediateRecord._values[0], record.getValues()[1]);
@@ -452,20 +414,17 @@ public class TableResizerTest {
}
// non-comparable intermediate result
- sel1.setColumn("distinctcount(m3)");
- selectionSort = Lists.newArrayList(sel1);
- _tableResizer = new TableResizer(dataSchema, aggregationInfos, selectionSort);
- AggregationFunction distinctCountFunction =
- AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfos.get(2)).getAggregationFunction();
+ selectionSort1.setColumn("distinctcount(m3)");
+ tableResizer = new TableResizer(_dataSchema, _aggregationFunctions, Collections.singletonList(selectionSort1));
+ AggregationFunction distinctCountFunction = _aggregationFunctions[2];
for (Map.Entry<Key, Record> entry : _recordsMap.entrySet()) {
Key key = entry.getKey();
Record record = entry.getValue();
- TableResizer.IntermediateRecord intermediateRecord =
- _tableResizer.getIntermediateRecord(key, record);
+ TableResizer.IntermediateRecord intermediateRecord = tableResizer.getIntermediateRecord(key, record);
Assert.assertEquals(intermediateRecord._key, key);
Assert.assertEquals(intermediateRecord._values.length, 1);
- Assert.assertEquals(intermediateRecord._values[0],
- distinctCountFunction.extractFinalResult(record.getValues()[5]));
+ Assert
+ .assertEquals(intermediateRecord._values[0], distinctCountFunction.extractFinalResult(record.getValues()[5]));
}
}
@@ -479,7 +438,7 @@ public class TableResizerTest {
selectionSort.setColumn("STRING_COL");
selectionSort.setIsAsc(true);
- TableResizer tableResizer = new TableResizer(schema, Collections.emptyList(), Lists.newArrayList(selectionSort));
+ TableResizer tableResizer = new TableResizer(schema, new AggregationFunction[0], Lists.newArrayList(selectionSort));
Set<Record> uniqueRecordsSet = new HashSet<>();
Record r1 = new Record(new Object[]{"B"});
@@ -559,7 +518,7 @@ public class TableResizerTest {
// change the order to DESC
selectionSort.setIsAsc(false);
- tableResizer = new TableResizer(schema, Collections.emptyList(), Lists.newArrayList(selectionSort));
+ tableResizer = new TableResizer(schema, new AggregationFunction[0], Lists.newArrayList(selectionSort));
trimSize = 5;
// no records should have been evicted
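The TableResizerTest changes above only swap the constructor inputs: an `AggregationFunction[]` replaces `List<AggregationInfo>`, and per-test `SelectionSort` locals replace shared fields. The behavior under test is unchanged. After `resizeRecordsMap(recordsMap, trimToSize)` exactly `trimToSize` records remain, chosen by the ORDER BY columns with later columns breaking ties, and a non-comparable intermediate result (the set behind `distinctcount(m3)`) is compared through its final result. Below is a minimal standalone sketch of that contract, not Pinot's TableResizer. The class `TopNTrimSketch`, the `{d1, d3, m3Set}` row layout, the sample rows, and the bounded max-heap eviction are all assumptions for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical sketch of top-N trimming for a group-by map; not Pinot's TableResizer. */
public final class TopNTrimSketch {

  /** Evicts entries in place so only the trimToSize entries that sort first under 'order' remain. */
  static <K, V> void trimToTopN(Map<K, V> map, int trimToSize, Comparator<V> order) {
    if (map.size() <= trimToSize) {
      return; // nothing to evict
    }
    // Max-heap under 'order': the head is the worst retained entry, so it is evicted on overflow.
    PriorityQueue<Map.Entry<K, V>> heap =
        new PriorityQueue<>(trimToSize + 1, (a, b) -> order.compare(b.getValue(), a.getValue()));
    for (Map.Entry<K, V> entry : map.entrySet()) {
      heap.offer(entry);
      if (heap.size() > trimToSize) {
        heap.poll();
      }
    }
    Set<K> keep = new HashSet<>();
    for (Map.Entry<K, V> entry : heap) {
      keep.add(entry.getKey());
    }
    map.keySet().retainAll(keep); // modify the map only after iteration has finished
  }

  /** Turns a non-comparable intermediate value (a set of distinct values) into a comparable count. */
  static int distinctCount(Object intermediate) {
    return ((Set<?>) intermediate).size();
  }

  /** ORDER BY d1 ASC, d3 DESC over rows laid out as {d1, d3, m3Set}. */
  static final Comparator<Object[]> D1_ASC_D3_DESC = (a, b) -> {
    int cmp = ((String) a[0]).compareTo((String) b[0]);
    if (cmp != 0) {
      return cmp;
    }
    return Double.compare((Double) b[1], (Double) a[1]); // descending tie-breaker
  };

  /** ORDER BY distinctcount(m3) DESC, d1 ASC. */
  static final Comparator<Object[]> DISTINCT_DESC_D1_ASC = (a, b) -> {
    int cmp = Integer.compare(distinctCount(b[2]), distinctCount(a[2])); // descending
    if (cmp != 0) {
      return cmp;
    }
    return ((String) a[0]).compareTo((String) b[0]); // ascending tie-breaker
  };

  /** Made-up sample rows keyed by a group key. */
  static Map<String, Object[]> sampleRecords() {
    Map<String, Object[]> records = new HashMap<>();
    records.put("k0", new Object[]{"a", 300.0, new HashSet<>(Arrays.asList(1, 2, 3))});
    records.put("k1", new Object[]{"b", 200.0, new HashSet<>(Arrays.asList(1, 2, 3, 4))});
    records.put("k2", new Object[]{"c", 100.0, new HashSet<>(Arrays.asList(1))});
    records.put("k3", new Object[]{"c", 50.0, new HashSet<>(Arrays.asList(1, 2, 3))});
    records.put("k4", new Object[]{"c", 400.0, new HashSet<>(Arrays.asList(5, 6))});
    return records;
  }

  public static void main(String[] args) {
    Map<String, Object[]> byD1 = sampleRecords();
    trimToTopN(byD1, 3, D1_ASC_D3_DESC);
    System.out.println(new TreeSet<>(byD1.keySet())); // [k0, k1, k4]

    Map<String, Object[]> byDistinct = sampleRecords();
    trimToTopN(byDistinct, 2, DISTINCT_DESC_D1_ASC); // k0 beats k3 on the d1 tie-breaker
    System.out.println(new TreeSet<>(byDistinct.keySet())); // [k0, k1]
  }
}
```

A bounded heap keeps the extra memory at O(trimToSize) and runs in O(n log trimToSize). Sorting the whole map and truncating would be simpler but costs O(n log n).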
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
index 442d86b..9c7dfed 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java
@@ -20,7 +20,6 @@ package org.apache.pinot.core.query.aggregation.function;
import java.util.Arrays;
import java.util.Collections;
-import java.util.List;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
@@ -36,7 +35,7 @@ public class AggregationFunctionFactoryTest {
AggregationFunction aggregationFunction;
BrokerRequest brokerRequest = new BrokerRequest();
- String column = "testColumn";
+ String column;
AggregationInfo aggregationInfo = new AggregationInfo();
aggregationInfo.setAggregationType("CoUnT");
@@ -270,9 +269,9 @@ public class AggregationFunctionFactoryTest {
AggregationInfo aggregationInfo = new AggregationInfo();
aggregationInfo.setAggregationType("distinct");
- List<String> args = Arrays.asList("column1", "column2", "column3");
- String expected = "distinct_" + AggregationFunctionUtils.concatArgs(args);
- aggregationInfo.setExpressions(args);
+ String[] arguments = new String[]{"column1", "column2", "column3"};
+ String expected = "distinct_" + AggregationFunctionUtils.concatArgs(arguments);
+ aggregationInfo.setExpressions(Arrays.asList(arguments));
AggregationFunction aggregationFunction =
AggregationFunctionFactory.getAggregationFunction(aggregationInfo, brokerRequest);
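The AggregationFunctionFactoryTest hunks rely on two behaviors. First, the aggregation type is matched case-insensitively (`"CoUnT"`). Second, a multi-argument `distinct` names its result `"distinct_" + concatArgs(arguments)`, with the arguments now held in a `String[]` and wrapped with `Arrays.asList` for `setExpressions`. The sketch below illustrates both ideas with hypothetical names. The enum, `columnName`, and especially the `:` separator inside `concatArgs` are assumptions: the test compares against `concatArgs` itself, not against a literal string.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

/** Hypothetical sketch; the enum, method names, and ':' separator are assumptions, not Pinot's API. */
public final class AggregationNameSketch {

  enum AggType { COUNT, SUM, MAX, AVG, DISTINCTCOUNT, DISTINCT }

  /** Case-insensitive lookup, so "CoUnT" and "count" both resolve to COUNT. */
  static AggType resolveType(String name) {
    return AggType.valueOf(name.toUpperCase(Locale.ROOT));
  }

  /** Joins several arguments into one key; a single argument is returned unchanged. */
  static String concatArgs(String[] arguments) {
    return arguments.length > 1 ? String.join(":", arguments) : arguments[0];
  }

  /** Builds a result column name such as "distinct_column1:column2:column3". */
  static String columnName(AggType type, String[] arguments) {
    return type.name().toLowerCase(Locale.ROOT) + "_" + concatArgs(arguments);
  }

  public static void main(String[] args) {
    String[] arguments = {"column1", "column2", "column3"};
    List<String> expressions = Arrays.asList(arguments); // list view for list-based setters
    System.out.println(resolveType("CoUnT")); // COUNT
    System.out.println(columnName(resolveType("distinct"), arguments)); // distinct_column1:column2:column3
    System.out.println(expressions.size()); // 3
  }
}
```

`Arrays.asList` returns a fixed-size list view backed by the array rather than a copy, so the same `String[]` can feed both an array-based and a list-based API without duplicating the elements.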
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
index 03177ae..daa5ddb 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/startree/v2/BaseStarTreeV2Test.java
@@ -30,12 +30,7 @@ import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.io.FileUtils;
-import org.apache.pinot.spi.config.table.TableConfig;
-import org.apache.pinot.spi.config.table.TableType;
-import org.apache.pinot.spi.data.FieldSpec.DataType;
-import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.common.function.AggregationFunctionType;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.GroupBy;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
@@ -46,7 +41,6 @@ import org.apache.pinot.core.common.BlockDocIdIterator;
import org.apache.pinot.core.common.BlockSingleValIterator;
import org.apache.pinot.core.common.Constants;
import org.apache.pinot.core.common.DataSource;
-import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.core.data.aggregator.ValueAggregator;
import org.apache.pinot.core.data.readers.GenericRowRecordReader;
import org.apache.pinot.core.indexsegment.IndexSegment;
@@ -54,6 +48,7 @@ import org.apache.pinot.core.indexsegment.generator.SegmentGeneratorConfig;
import org.apache.pinot.core.indexsegment.immutable.ImmutableSegmentLoader;
import org.apache.pinot.core.plan.FilterPlanNode;
import org.apache.pinot.core.plan.PlanNode;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
import org.apache.pinot.core.segment.index.readers.Dictionary;
@@ -62,6 +57,11 @@ import org.apache.pinot.core.startree.v2.builder.MultipleTreesBuilder;
import org.apache.pinot.core.startree.v2.builder.MultipleTreesBuilder.BuildMode;
import org.apache.pinot.core.startree.v2.builder.StarTreeV2BuilderConfig;
import org.apache.pinot.pql.parsers.Pql2Compiler;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
@@ -69,6 +69,7 @@ import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
/**
@@ -118,14 +119,12 @@ abstract class BaseStarTreeV2Test<R, A> {
List<GenericRow> segmentRecords = new ArrayList<>(NUM_SEGMENT_RECORDS);
for (int i = 0; i < NUM_SEGMENT_RECORDS; i++) {
- Map<String, Object> fieldMap = new HashMap<>();
- fieldMap.put(DIMENSION_D1, RANDOM.nextInt(DIMENSION_CARDINALITY));
- fieldMap.put(DIMENSION_D2, RANDOM.nextInt(DIMENSION_CARDINALITY));
+ GenericRow segmentRecord = new GenericRow();
+ segmentRecord.putValue(DIMENSION_D1, RANDOM.nextInt(DIMENSION_CARDINALITY));
+ segmentRecord.putValue(DIMENSION_D2, RANDOM.nextInt(DIMENSION_CARDINALITY));
if (rawValueType != null) {
- fieldMap.put(METRIC, getRandomRawValue(RANDOM));
+ segmentRecord.putValue(METRIC, getRandomRawValue(RANDOM));
}
- GenericRow segmentRecord = new GenericRow();
- segmentRecord.init(fieldMap);
segmentRecords.add(segmentRecord);
}
@@ -184,11 +183,14 @@ abstract class BaseStarTreeV2Test<R, A> {
BrokerRequest brokerRequest = COMPILER.compileToBrokerRequest(query);
// Aggregations
- List<AggregationInfo> aggregationInfos = brokerRequest.getAggregationsInfo();
- int numAggregations = aggregationInfos.size();
+ AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(brokerRequest);
+ int numAggregations = aggregationFunctions.length;
List<AggregationFunctionColumnPair> functionColumnPairs = new ArrayList<>(numAggregations);
- for (AggregationInfo aggregationInfo : aggregationInfos) {
- functionColumnPairs.add(AggregationFunctionUtils.getFunctionColumnPair(aggregationInfo));
+ for (AggregationFunction aggregationFunction : aggregationFunctions) {
+ AggregationFunctionColumnPair aggregationFunctionColumnPair =
+ AggregationFunctionUtils.getAggregationFunctionColumnPair(aggregationFunction);
+ assertNotNull(aggregationFunctionColumnPair);
+ functionColumnPairs.add(aggregationFunctionColumnPair);
}
// Group-by columns
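In BaseStarTreeV2Test the aggregations are now built once with `AggregationFunctionUtils.getAggregationFunctions(brokerRequest)`, and each one is converted to an `AggregationFunctionColumnPair`. The added `assertNotNull` implies that this conversion can return null. The sketch below shows that nullable-mapping pattern in isolation. Its types and rules are hypothetical. It assumes a pair exists only for a single plain-column argument, or for `count` on `*` (the commit message notes that COUNT with Star-Tree uses the `"*"` column), and it fails fast where the test asserts.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/** Hypothetical sketch of mapping aggregations to (function, column) pairs; not Pinot's API. */
public final class FunctionColumnPairSketch {

  /** Minimal stand-in for an aggregation: a function type plus its argument expressions. */
  static final class Aggregation {
    final String type;
    final List<String> arguments;

    Aggregation(String type, String... arguments) {
      this.type = type;
      this.arguments = Arrays.asList(arguments);
    }
  }

  /** Immutable (function, column) pair with value equality. */
  static final class Pair {
    final String function;
    final String column;

    Pair(String function, String column) {
      this.function = function;
      this.column = column;
    }

    @Override
    public boolean equals(Object o) {
      return o instanceof Pair && ((Pair) o).function.equals(function) && ((Pair) o).column.equals(column);
    }

    @Override
    public int hashCode() {
      return Objects.hash(function, column);
    }

    @Override
    public String toString() {
      return function + "__" + column; // "__" separator is an assumption, used for display only
    }
  }

  /** Returns a pair for a single plain column (or count on "*"), otherwise null (assumed rule). */
  static Pair toPair(Aggregation aggregation) {
    if (aggregation.arguments.size() != 1) {
      return null;
    }
    String argument = aggregation.arguments.get(0);
    boolean isCountStar = argument.equals("*") && aggregation.type.equals("count");
    boolean isPlainColumn = argument.matches("[A-Za-z_][A-Za-z0-9_]*");
    return isCountStar || isPlainColumn ? new Pair(aggregation.type, argument) : null;
  }

  /** Converts every aggregation, failing fast where the test uses assertNotNull. */
  static List<Pair> toPairs(List<Aggregation> aggregations) {
    List<Pair> pairs = new ArrayList<>(aggregations.size());
    for (Aggregation aggregation : aggregations) {
      Pair pair = toPair(aggregation);
      if (pair == null) {
        throw new IllegalStateException("No function-column pair for " + aggregation.type + aggregation.arguments);
      }
      pairs.add(pair);
    }
    return pairs;
  }

  public static void main(String[] args) {
    List<Pair> pairs = toPairs(Arrays.asList(new Aggregation("count", "*"), new Aggregation("sum", "m1")));
    System.out.println(pairs); // [count__*, sum__m1]
    System.out.println(toPair(new Aggregation("sum", "add(m1,m2)"))); // null
  }
}
```

Returning null from the mapper and checking at the call site keeps the mapper reusable in contexts where an unsupported aggregation is expected (for example, falling back to a non-star-tree plan), while a test can still insist that every pair exists.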
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
index efc3d40..3ba51c2 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationMultiValueQueriesTest.java
@@ -71,33 +71,41 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
aggregations.add(new String[]{"35436"});
aggregations.add(new String[]{"33576"});
aggregations.add(new String[]{"24300"});
- QueriesTestUtils
- .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
- groupKeys, aggregations);
+ QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+ aggregations);
- query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000)";
+ query =
+ "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000)";
brokerResponse = getBrokerResponseForSqlQuery(query);
DataSchema expectedDataSchema =
- new DataSchema(new String[]{"valuein(column7,'363','469','246','100000')", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- Lists.newArrayList(new Object[]{469, (long)33576}, new Object[]{246, (long)24300}, new Object[]{363, (long)35436}), 3, expectedDataSchema);
-
- query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6)";
+ new DataSchema(new String[]{"valuein(column7,'363','469','246','100000')", "countmv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+ QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+ .newArrayList(new Object[]{469, (long) 33576}, new Object[]{246, (long) 24300},
+ new Object[]{363, (long) 35436}), 3, expectedDataSchema);
+
+ query =
+ "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6)";
brokerResponse = getBrokerResponseForSqlQuery(query);
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- Lists.newArrayList(new Object[]{246, (long)24300}, new Object[]{469, (long)33576}, new Object[]{363, (long)35436}), 3, expectedDataSchema);
+ QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+ .newArrayList(new Object[]{246, (long) 24300}, new Object[]{469, (long) 33576},
+ new Object[]{363, (long) 35436}), 3, expectedDataSchema);
- query = "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6) DESC";
+ query =
+ "SELECT VALUEIN(column7, 363, 469, 246, 100000), COUNTMV(column6) FROM testTable GROUP BY VALUEIN(column7, 363, 469, 246, 100000) ORDER BY COUNTMV(column6) DESC";
brokerResponse = getBrokerResponseForSqlQuery(query);
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- Lists.newArrayList(new Object[]{363, (long)35436}, new Object[]{469, (long)33576}, new Object[]{246, (long)24300}), 3, expectedDataSchema);
-
- query = "SELECT VALUEIN(column7, 363, 469, 246, 100000) AS value_in_col, COUNTMV(column6) FROM testTable GROUP BY value_in_col ORDER BY COUNTMV(column6) DESC";
- expectedDataSchema =
- new DataSchema(new String[]{"value_in_col", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+ QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+ .newArrayList(new Object[]{363, (long) 35436}, new Object[]{469, (long) 33576},
+ new Object[]{246, (long) 24300}), 3, expectedDataSchema);
+
+ query =
+ "SELECT VALUEIN(column7, 363, 469, 246, 100000) AS value_in_col, COUNTMV(column6) FROM testTable GROUP BY value_in_col ORDER BY COUNTMV(column6) DESC";
+ expectedDataSchema = new DataSchema(new String[]{"value_in_col", "countmv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
brokerResponse = getBrokerResponseForSqlQuery(query);
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- Lists.newArrayList(new Object[]{363, (long)35436}, new Object[]{469, (long)33576}, new Object[]{246, (long)24300}), 3, expectedDataSchema);
+ QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, Lists
+ .newArrayList(new Object[]{363, (long) 35436}, new Object[]{469, (long) 33576},
+ new Object[]{246, (long) 24300}), 3, expectedDataSchema);
query = "SELECT COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch";
brokerResponse = getBrokerResponseForPqlQuery(query);
@@ -105,18 +113,18 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
groupKeys.add(new String[]{"1756015683"});
aggregations = new ArrayList<>();
aggregations.add(new String[]{"426752"});
- QueriesTestUtils
- .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
- groupKeys, aggregations);
+ QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+ aggregations);
- query = "SELECT daysSinceEpoch, COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch ORDER BY COUNTMV(column6) DESC";
- expectedDataSchema =
- new DataSchema(new String[]{"daysSinceEpoch", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
+ query =
+ "SELECT daysSinceEpoch, COUNTMV(column6) FROM testTable GROUP BY daysSinceEpoch ORDER BY COUNTMV(column6) DESC";
+ expectedDataSchema = new DataSchema(new String[]{"daysSinceEpoch", "countmv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.LONG});
brokerResponse = getBrokerResponseForSqlQuery(query);
List<Object[]> result = new ArrayList<>();
- result.add(new Object[]{1756015683, (long)426752});
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- result, 1, expectedDataSchema);
+ result.add(new Object[]{1756015683, (long) 426752});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, result, 1, expectedDataSchema);
query = "SELECT COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS')";
brokerResponse = getBrokerResponseForPqlQuery(query);
@@ -124,18 +132,18 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
groupKeys.add(new String[]{"42144376392"});
aggregations = new ArrayList<>();
aggregations.add(new String[]{"426752"});
- QueriesTestUtils
- .testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
- groupKeys, aggregations);
+ QueriesTestUtils.testInterSegmentAggregationGroupByResult(brokerResponse, 400000L, 0L, 800000L, 400000L, groupKeys,
+ aggregations);
- query = "SELECT timeconvert(daysSinceEpoch, 'DAYS', 'HOURS'), COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS') ORDER BY COUNTMV(column6) DESC";
- expectedDataSchema =
- new DataSchema(new String[]{"timeconvert(daysSinceEpoch,'DAYS','HOURS')", "countmv(column6)"}, new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.LONG});
+ query =
+ "SELECT timeconvert(daysSinceEpoch, 'DAYS', 'HOURS'), COUNTMV(column6) FROM testTable GROUP BY timeconvert(daysSinceEpoch, 'DAYS', 'HOURS') ORDER BY COUNTMV(column6) DESC";
+ expectedDataSchema = new DataSchema(new String[]{"timeconvert(daysSinceEpoch,'DAYS','HOURS')", "countmv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.LONG});
brokerResponse = getBrokerResponseForSqlQuery(query);
result = new ArrayList<>();
- result.add(new Object[]{42144376392L, (long)426752});
- QueriesTestUtils.testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L,
- result, 1, expectedDataSchema);
+ result.add(new Object[]{42144376392L, (long) 426752});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, result, 1, expectedDataSchema);
}
@Test
@@ -340,23 +348,27 @@ public class InterSegmentAggregationMultiValueQueriesTest extends BaseMultiValue
@Test
public void testPercentile50MV() {
- String query = "SELECT PERCENTILE50MV(column6) FROM testTable";
-
- BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 400000L, 400000L,
- new String[]{"2147483647.00000"});
-
- brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 62480L, 1101664L, 62480L, 400000L,
- new String[]{"2147483647.00000"});
-
- brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
- new String[]{"2147483647.00000"});
-
- brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
- new String[]{"2147483647.00000"});
+ List<String> queries = Arrays
+ .asList("SELECT PERCENTILE50MV(column6) FROM testTable", "SELECT PERCENTILEMV(column6, 50) FROM testTable",
+ "SELECT PERCENTILEMV(column6, '50') FROM testTable", "SELECT PERCENTILEMV(column6, \"50\") FROM testTable");
+
+ for (String query : queries) {
+ BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 400000L, 400000L,
+ new String[]{"2147483647.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 62480L, 1101664L, 62480L, 400000L,
+ new String[]{"2147483647.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
+ new String[]{"2147483647.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 400000L, 0L, 800000L, 400000L,
+ new String[]{"2147483647.00000"});
+ }
}
@Test
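The tests above check four spellings of the same query: the legacy PERCENTILE50MV form, and a second argument given as bare 50, single-quoted '50', or double-quoted "50". The sketch below shows one way such arguments could be normalized. It is a hypothetical helper, PercentileArgs.resolve, and not Pinot's actual parsing code. It assumes the percentile arrives either as a numeric suffix of the function name or as the second argument.

```java
import java.util.List;
import java.util.Locale;

/**
 * Hypothetical helper (not Pinot's implementation) that resolves the percentile
 * for PERCENTILE / PERCENTILEMV style calls, e.g. PERCENTILE50MV(col) or PERCENTILEMV(col, '50').
 */
final class PercentileArgs {
  private static final String PREFIX = "PERCENTILE";

  private PercentileArgs() {
  }

  static double resolve(String functionName, List<String> arguments) {
    String name = functionName.toUpperCase(Locale.ROOT);
    if (!name.startsWith(PREFIX)) {
      throw new IllegalArgumentException("Not a percentile function: " + functionName);
    }
    // Remainder after the prefix, e.g. "50MV" for PERCENTILE50MV or "MV" for PERCENTILEMV.
    String suffix = name.substring(PREFIX.length());
    if (suffix.endsWith("MV")) {
      suffix = suffix.substring(0, suffix.length() - 2);
    }
    String percentileText;
    if (!suffix.isEmpty()) {
      // Legacy form: the percentile is encoded in the function name.
      if (!suffix.chars().allMatch(Character::isDigit)) {
        throw new IllegalArgumentException("Unsupported function: " + functionName);
      }
      percentileText = suffix;
    } else if (arguments.size() == 2) {
      // Second-argument form: the percentile may be bare, single-quoted or double-quoted.
      percentileText = stripQuotes(arguments.get(1).trim());
    } else {
      throw new IllegalArgumentException("Missing percentile for: " + functionName);
    }
    double percentile = Double.parseDouble(percentileText);
    if (Double.isNaN(percentile) || percentile < 0 || percentile > 100) {
      throw new IllegalArgumentException("Percentile out of range [0, 100]: " + percentileText);
    }
    return percentile;
  }

  /** Removes one pair of matching single or double quotes, if present. */
  private static String stripQuotes(String text) {
    if (text.length() >= 2) {
      char first = text.charAt(0);
      char last = text.charAt(text.length() - 1);
      if ((first == '\'' || first == '"') && first == last) {
        return text.substring(1, text.length() - 1);
      }
    }
    return text;
  }
}
```

With this sketch, PERCENTILE50MV(column6) and PERCENTILEMV(column6, '50') both resolve to 50.0. This matches the expectation in the tests that every spelling yields the same result.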
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
index 8b18adb..70654d4 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentAggregationSingleValueQueriesTest.java
@@ -19,6 +19,8 @@
package org.apache.pinot.queries;
import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
import java.util.function.Function;
import org.apache.pinot.common.response.broker.BrokerResponseNative;
import org.apache.pinot.common.response.broker.SelectionResults;
@@ -238,23 +240,28 @@ public class InterSegmentAggregationSingleValueQueriesTest extends BaseSingleVal
@Test
public void testPercentile50() {
- String query = "SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable";
-
- BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 240000L, 120000L,
- new String[]{"1107310944.00000", "1080136306.00000"});
-
- brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 49032L, 120000L,
- new String[]{"1139674505.00000", "505053732.00000"});
-
- brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 360000L, 120000L,
- new String[]{"2146791843.00000", "2141451242.00000"});
-
- brokerResponse = getBrokerResponseForPqlQueryWithFilter(query + GROUP_BY);
- QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 73548L, 120000L,
- new String[]{"2142595699.00000", "999309554.00000"});
+ List<String> queries = Arrays.asList("SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable",
+ "SELECT PERCENTILE(column1, 50), PERCENTILE(column3, 50) FROM testTable",
+ "SELECT PERCENTILE(column1, '50'), PERCENTILE(column3, '50') FROM testTable",
+ "SELECT PERCENTILE(column1, \"50\"), PERCENTILE(column3, \"50\") FROM testTable");
+
+ for (String query : queries) {
+ BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 240000L, 120000L,
+ new String[]{"1107310944.00000", "1080136306.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQueryWithFilter(query);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 49032L, 120000L,
+ new String[]{"1139674505.00000", "505053732.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 120000L, 0L, 360000L, 120000L,
+ new String[]{"2146791843.00000", "2141451242.00000"});
+
+ brokerResponse = getBrokerResponseForPqlQueryWithFilter(query + GROUP_BY);
+ QueriesTestUtils.testInterSegmentAggregationResult(brokerResponse, 24516L, 336536L, 73548L, 120000L,
+ new String[]{"2142595699.00000", "999309554.00000"});
+ }
}
@Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
index dcf5d7b..3026b0f 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableMultiValueQueriesTest.java
@@ -19,6 +19,7 @@
package org.apache.pinot.queries;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -471,48 +472,52 @@ public class InterSegmentResultTableMultiValueQueriesTest extends BaseMultiValue
@Test
public void testPercentile50MV() {
+ List<String> queries = Arrays
+ .asList("SELECT PERCENTILE50MV(column6) FROM testTable", "SELECT PERCENTILEMV(column6, 50) FROM testTable",
+ "SELECT PERCENTILEMV(column6, '50') FROM testTable", "SELECT PERCENTILEMV(column6, \"50\") FROM testTable");
+
DataSchema dataSchema;
List<Object[]> rows;
int expectedResultsSize;
- String query = "SELECT PERCENTILE50MV(column6) FROM testTable";
Map<String, String> queryOptions = new HashMap<>(2);
queryOptions.put(CommonConstants.Broker.Request.QueryOptionKey.RESPONSE_FORMAT, CommonConstants.Broker.Request.SQL);
-
- BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
- dataSchema = new DataSchema(new String[]{"percentile50mv(column6)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});
- rows = new ArrayList<>();
- rows.add(new Object[]{2147483647.0});
- expectedResultsSize = 1;
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 400000L, 400000L, rows, expectedResultsSize,
- dataSchema);
-
- brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
- rows = new ArrayList<>();
- rows.add(new Object[]{2147483647.0});
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 62480L, 1101664L, 62480L, 400000L, rows, expectedResultsSize,
- dataSchema);
-
- brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY, queryOptions);
- dataSchema = new DataSchema(new String[]{"column8", "percentile50mv(column6)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
- rows = new ArrayList<>();
- rows.add(new Object[]{"169878844", 2147483647.0});
- expectedResultsSize = 10;
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
- dataSchema);
-
- brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY, queryOptions);
- dataSchema = new DataSchema(new String[]{"column7", "percentile50mv(column6)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
- rows = new ArrayList<>();
- rows.add(new Object[]{"372", 2147483647.0});
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
- dataSchema);
+ for (String query : queries) {
+ BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
+ dataSchema = new DataSchema(new String[]{"percentile50mv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE});
+ rows = new ArrayList<>();
+ rows.add(new Object[]{2147483647.0});
+ expectedResultsSize = 1;
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 400000L, 400000L, rows, expectedResultsSize,
+ dataSchema);
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
+ rows = new ArrayList<>();
+ rows.add(new Object[]{2147483647.0});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 62480L, 1101664L, 62480L, 400000L, rows, expectedResultsSize,
+ dataSchema);
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + SV_GROUP_BY, queryOptions);
+ dataSchema = new DataSchema(new String[]{"column8", "percentile50mv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+ rows = new ArrayList<>();
+ rows.add(new Object[]{"169878844", 2147483647.0});
+ expectedResultsSize = 10;
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
+ dataSchema);
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + MV_GROUP_BY, queryOptions);
+ dataSchema = new DataSchema(new String[]{"column7", "percentile50mv(column6)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+ rows = new ArrayList<>();
+ rows.add(new Object[]{"372", 2147483647.0});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 400000L, 0L, 800000L, 400000L, rows, expectedResultsSize,
+ dataSchema);
+ }
}
@Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
index fe6c6df..efa941c 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/InterSegmentResultTableSingleValueQueriesTest.java
@@ -20,6 +20,7 @@ package org.apache.pinot.queries;
import com.google.common.collect.Lists;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -482,47 +483,52 @@ public class InterSegmentResultTableSingleValueQueriesTest extends BaseSingleVal
@Test
public void testPercentile50() {
+ List<String> queries = Arrays.asList("SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable",
+ "SELECT PERCENTILE(column1, 50), PERCENTILE(column3, 50) FROM testTable",
+ "SELECT PERCENTILE(column1, '50'), PERCENTILE(column3, '50') FROM testTable",
+ "SELECT PERCENTILE(column1, \"50\"), PERCENTILE(column3, \"50\") FROM testTable");
+
DataSchema dataSchema;
List<Object[]> rows;
int expectedResultsSize;
- String query = "SELECT PERCENTILE50(column1), PERCENTILE50(column3) FROM testTable";
Map<String, String> queryOptions = new HashMap<>(2);
queryOptions.put(QueryOptionKey.RESPONSE_FORMAT, Request.SQL);
-
- BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
- dataSchema = new DataSchema(new String[]{"percentile50(column1)", "percentile50(column3)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
- rows = new ArrayList<>();
- rows.add(new Object[]{1107310944.0, 1080136306.0});
- expectedResultsSize = 1;
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
- dataSchema);
-
- brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
- rows = new ArrayList<>();
- rows.add(new Object[]{1139674505.0, 505053732.0});
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
- dataSchema);
-
- query = "SELECT PERCENTILE50(column3) FROM testTable";
- brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY, queryOptions);
- dataSchema = new DataSchema(new String[]{"column9", "percentile50(column3)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
- rows = new ArrayList<>();
- rows.add(new Object[]{"1642909995", 2141451242.0});
- expectedResultsSize = 10;
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
- dataSchema);
-
- brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY + getFilter(), queryOptions);
- rows = new ArrayList<>();
- rows.add(new Object[]{"438926263", 999309554.0});
- QueriesTestUtils
- .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
- dataSchema);
+ for (String query : queries) {
+ BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query, queryOptions);
+ dataSchema = new DataSchema(new String[]{"percentile50(column1)", "percentile50(column3)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+ rows = new ArrayList<>();
+ rows.add(new Object[]{1107310944.0, 1080136306.0});
+ expectedResultsSize = 1;
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
+ dataSchema);
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + getFilter(), queryOptions);
+ rows = new ArrayList<>();
+ rows.add(new Object[]{1139674505.0, 505053732.0});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
+ dataSchema);
+
+ query = "SELECT PERCENTILE50(column3) FROM testTable";
+ brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY, queryOptions);
+ dataSchema = new DataSchema(new String[]{"column9", "percentile50(column3)"},
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.DOUBLE});
+ rows = new ArrayList<>();
+ rows.add(new Object[]{"1642909995", 2141451242.0});
+ expectedResultsSize = 10;
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 120000L, 0L, 240000L, 120000L, rows, expectedResultsSize,
+ dataSchema);
+
+ brokerResponse = getBrokerResponseForPqlQuery(query + GROUP_BY + getFilter(), queryOptions);
+ rows = new ArrayList<>();
+ rows.add(new Object[]{"438926263", 999309554.0});
+ QueriesTestUtils
+ .testInterSegmentResultTable(brokerResponse, 24516L, 336536L, 49032L, 120000L, rows, expectedResultsSize,
+ dataSchema);
+ }
}
@Test
diff --git a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
index ee13450..1371f9a 100644
--- a/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/query/aggregation/DefaultAggregationExecutorTest.java
@@ -20,7 +20,6 @@ package org.apache.pinot.query.aggregation;
import java.io.File;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -28,7 +27,7 @@ import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.io.FileUtils;
-import org.apache.pinot.common.request.AggregationInfo;
+import org.apache.pinot.common.request.BrokerRequest;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.common.segment.ReadMode;
import org.apache.pinot.core.common.DataSource;
@@ -43,10 +42,11 @@ import org.apache.pinot.core.operator.filter.MatchAllFilterOperator;
import org.apache.pinot.core.operator.transform.TransformOperator;
import org.apache.pinot.core.plan.DocIdSetPlanNode;
import org.apache.pinot.core.query.aggregation.AggregationExecutor;
-import org.apache.pinot.core.query.aggregation.AggregationFunctionContext;
import org.apache.pinot.core.query.aggregation.DefaultAggregationExecutor;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils;
import org.apache.pinot.core.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.pql.parsers.Pql2Compiler;
import org.apache.pinot.spi.config.table.TableType;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.MetricFieldSpec;
@@ -88,8 +88,8 @@ public class DefaultAggregationExecutorTest {
public static IndexSegment _indexSegment;
private Random _random;
- private List<AggregationInfo> _aggregationInfoList;
private String[] _columns;
+ private BrokerRequest _brokerRequest;
private double[][] _inputData;
/**
@@ -110,15 +110,15 @@ public class DefaultAggregationExecutorTest {
_columns = new String[numColumns];
setupSegment();
- _aggregationInfoList = new ArrayList<>();
-
- for (int i = 0; i < _columns.length; i++) {
- AggregationInfo aggregationInfo = new AggregationInfo();
- aggregationInfo.setAggregationType(AGGREGATION_FUNCTIONS[i]);
-
- aggregationInfo.setExpressions(Collections.singletonList(_columns[i]));
- _aggregationInfoList.add(aggregationInfo);
+ StringBuilder queryBuilder = new StringBuilder("SELECT");
+ for (int i = 0; i < numColumns; i++) {
+ queryBuilder.append(String.format(" %s(%s)", AGGREGATION_FUNCTIONS[i], _columns[i]));
+ if (i != numColumns - 1) {
+ queryBuilder.append(',');
+ }
}
+ queryBuilder.append(" FROM testTable");
+ _brokerRequest = new Pql2Compiler().compileToBrokerRequest(queryBuilder.toString());
}
/**
@@ -139,13 +139,8 @@ public class DefaultAggregationExecutorTest {
ProjectionOperator projectionOperator = new ProjectionOperator(dataSourceMap, docIdSetOperator);
TransformOperator transformOperator = new TransformOperator(projectionOperator, expressionTrees);
TransformBlock transformBlock = transformOperator.nextBlock();
- int numAggFuncs = _aggregationInfoList.size();
- AggregationFunctionContext[] aggrFuncContextArray = new AggregationFunctionContext[numAggFuncs];
- for (int i = 0; i < numAggFuncs; i++) {
- AggregationInfo aggregationInfo = _aggregationInfoList.get(i);
- aggrFuncContextArray[i] = AggregationFunctionUtils.getAggregationFunctionContext(aggregationInfo);
- }
- AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggrFuncContextArray);
+ AggregationFunction[] aggregationFunctions = AggregationFunctionUtils.getAggregationFunctions(_brokerRequest);
+ AggregationExecutor aggregationExecutor = new DefaultAggregationExecutor(aggregationFunctions);
aggregationExecutor.aggregate(transformBlock);
List<Object> result = aggregationExecutor.getResult();
for (int i = 0; i < result.size(); i++) {
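The rewritten setup no longer assembles AggregationInfo objects by hand. Instead, it builds a PQL string of the form SELECT f0(col0), f1(col1), ... FROM testTable and compiles it with Pql2Compiler. The same string can be produced with java.util.StringJoiner, which handles the comma placement. The helper below, QueryStrings.buildAggregationQuery, has a hypothetical name and only illustrates the string construction. It does not call any Pinot API.

```java
import java.util.StringJoiner;

/** Hypothetical helper that mirrors the StringBuilder loop in the test setup. */
final class QueryStrings {
  private QueryStrings() {
  }

  /** Builds "SELECT f0(c0), f1(c1), ... FROM <table>" from parallel arrays. */
  static String buildAggregationQuery(String[] functions, String[] columns, String table) {
    if (functions.length != columns.length) {
      throw new IllegalArgumentException("functions and columns must have the same length");
    }
    // The joiner puts ", " between items and wraps them with the SELECT and FROM clauses.
    StringJoiner joiner = new StringJoiner(", ", "SELECT ", " FROM " + table);
    for (int i = 0; i < functions.length; i++) {
      joiner.add(String.format("%s(%s)", functions[i], columns[i]));
    }
    return joiner.toString();
  }
}
```

For functions {"sum", "max"} on columns {"c0", "c1"}, this returns SELECT sum(c0), max(c1) FROM testTable. That is the same text the StringBuilder loop produces.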
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
index 954db47..05b2e1b 100644
--- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkCombineGroupBy.java
@@ -19,8 +19,9 @@
package org.apache.pinot.perf;
import com.google.common.base.Joiner;
-import com.google.common.collect.Lists;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -37,7 +38,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.RandomStringUtils;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.GroupBy;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
@@ -45,7 +45,8 @@ import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
import org.apache.pinot.core.data.table.IndexedTable;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
-import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService;
import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
import org.apache.pinot.core.query.utils.Pair;
@@ -77,11 +78,9 @@ public class BenchmarkCombineGroupBy {
private Random _random = new Random();
private DataSchema _dataSchema;
- private List<AggregationInfo> _aggregationInfos;
- private GroupBy _groupBy;
private AggregationFunction[] _aggregationFunctions;
+ private GroupBy _groupBy;
private List<SelectionSort> _orderBy;
- private int _numAggregationFunctions;
private List<String> _d1;
private List<Integer> _d2;
@@ -105,36 +104,19 @@ public class BenchmarkCombineGroupBy {
}
_dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
- DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
-
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>(1);
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>(1);
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
- _aggregationInfos = Lists.newArrayList(agg1, agg2);
-
- _numAggregationFunctions = 2;
- _aggregationFunctions = new AggregationFunction[_numAggregationFunctions];
- for (int i = 0; i < _numAggregationFunctions; i++) {
- _aggregationFunctions[i] = AggregationFunctionFactory.getAggregationFunction(_aggregationInfos.get(i), null);
- }
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+
+ _aggregationFunctions =
+ new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
_groupBy = new GroupBy();
_groupBy.setTopN(TOP_N);
- _groupBy.setExpressions(Lists.newArrayList("d1", "d2"));
+ _groupBy.setExpressions(Arrays.asList("d1", "d2"));
- SelectionSort orderBy = new SelectionSort();
- orderBy.setColumn("sum(m1)");
- orderBy.setIsAsc(true);
- _orderBy = Lists.newArrayList(orderBy);
+ SelectionSort selectionSort = new SelectionSort();
+ selectionSort.setColumn("sum(m1)");
+ selectionSort.setIsAsc(true);
+ _orderBy = Collections.singletonList(selectionSort);
_executorService = Executors.newFixedThreadPool(10);
}
@@ -161,13 +143,14 @@ public class BenchmarkCombineGroupBy {
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
- public void concurrentIndexedTableForCombineGroupBy() throws InterruptedException, ExecutionException, TimeoutException {
+ public void concurrentIndexedTableForCombineGroupBy()
+ throws InterruptedException, ExecutionException, TimeoutException {
int capacity = GroupByUtils.getTableCapacity(_groupBy, _orderBy);
// make 1 concurrent table
IndexedTable concurrentIndexedTable =
- new ConcurrentIndexedTable(_dataSchema, _aggregationInfos, _orderBy, capacity);
+ new ConcurrentIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, capacity);
List<Callable<Void>> innerSegmentCallables = new ArrayList<>(NUM_SEGMENTS);
@@ -193,11 +176,11 @@ public class BenchmarkCombineGroupBy {
concurrentIndexedTable.finish(false);
}
-
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
- public void originalCombineGroupBy() throws InterruptedException, TimeoutException, ExecutionException {
+ public void originalCombineGroupBy()
+ throws InterruptedException, TimeoutException, ExecutionException {
AtomicInteger numGroups = new AtomicInteger();
int _interSegmentNumGroupsLimit = 200_000;
@@ -213,15 +196,14 @@ public class BenchmarkCombineGroupBy {
final Object[] value = newRecordOriginal.getSecond();
resultsMap.compute(stringKey, (k, v) -> {
+ int numAggregationFunctions = _aggregationFunctions.length;
if (v == null) {
if (numGroups.getAndIncrement() < _interSegmentNumGroupsLimit) {
- v = new Object[_numAggregationFunctions];
- for (int j = 0; j < _numAggregationFunctions; j++) {
- v[j] = value[j];
- }
+ v = new Object[numAggregationFunctions];
+ System.arraycopy(value, 0, v, 0, numAggregationFunctions);
}
} else {
- for (int j = 0; j < _numAggregationFunctions; j++) {
+ for (int j = 0; j < numAggregationFunctions; j++) {
v[j] = _aggregationFunctions[j].merge(v[j], value[j]);
}
}
@@ -243,13 +225,11 @@ public class BenchmarkCombineGroupBy {
List<Map<String, Object>> trimmedResults = aggregationGroupByTrimmingService.trimIntermediateResultsMap(resultsMap);
}
- public static void main(String[] args) throws Exception {
- ChainedOptionsBuilder opt = new OptionsBuilder().include(BenchmarkCombineGroupBy.class.getSimpleName())
- .warmupTime(TimeValue.seconds(10))
- .warmupIterations(1)
- .measurementTime(TimeValue.seconds(30))
- .measurementIterations(3)
- .forks(1);
+ public static void main(String[] args)
+ throws Exception {
+ ChainedOptionsBuilder opt =
+ new OptionsBuilder().include(BenchmarkCombineGroupBy.class.getSimpleName()).warmupTime(TimeValue.seconds(10))
+ .warmupIterations(1).measurementTime(TimeValue.seconds(30)).measurementIterations(3).forks(1);
new Runner(opt.build()).run();
}
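The originalCombineGroupBy benchmark merges per-segment rows into a shared map using ConcurrentHashMap.compute. A new group is admitted only while a shared counter stays below the group limit. An existing group has each aggregation slot merged in place. The sketch below is a simplified, self-contained version of that pattern. It assumes double-valued intermediate results and uses DoubleBinaryOperator mergers (Double::sum, Math::max) as stand-ins for AggregationFunction.merge. The class name GroupByCombiner is hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.DoubleBinaryOperator;

/** Hypothetical combiner illustrating the compute-based merge used in the benchmark. */
final class GroupByCombiner {
  private final ConcurrentHashMap<String, double[]> _resultsMap = new ConcurrentHashMap<>();
  private final AtomicInteger _numGroups = new AtomicInteger();
  private final DoubleBinaryOperator[] _mergers;
  private final int _numGroupsLimit;

  GroupByCombiner(DoubleBinaryOperator[] mergers, int numGroupsLimit) {
    _mergers = mergers.clone();
    _numGroupsLimit = numGroupsLimit;
  }

  /** Merges one row (group key plus one value per aggregation) into the shared map. */
  void add(String groupKey, double[] values) {
    int numAggregations = _mergers.length;
    _resultsMap.compute(groupKey, (k, v) -> {
      if (v == null) {
        // New group: admit it only while under the limit; returning null leaves it out of the map.
        if (_numGroups.getAndIncrement() < _numGroupsLimit) {
          v = new double[numAggregations];
          System.arraycopy(values, 0, v, 0, numAggregations);
        }
      } else {
        // Existing group: merge each aggregation slot with its own merge function.
        for (int j = 0; j < numAggregations; j++) {
          v[j] = _mergers[j].applyAsDouble(v[j], values[j]);
        }
      }
      return v;
    });
  }

  Map<String, double[]> results() {
    return _resultsMap;
  }
}
```

As in the benchmark, the counter is also incremented for rejected groups, so it can exceed the limit. Admission stays correct because once the limit is reached, every later comparison fails.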
diff --git a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
index 7312b6b..0b17edb 100644
--- a/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
+++ b/pinot-perf/src/main/java/org/apache/pinot/perf/BenchmarkIndexedTable.java
@@ -18,8 +18,8 @@
*/
package org.apache.pinot.perf;
-import com.google.common.collect.Lists;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
@@ -33,13 +33,15 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.RandomStringUtils;
-import org.apache.pinot.common.request.AggregationInfo;
import org.apache.pinot.common.request.SelectionSort;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.data.table.ConcurrentIndexedTable;
import org.apache.pinot.core.data.table.IndexedTable;
import org.apache.pinot.core.data.table.Record;
import org.apache.pinot.core.data.table.SimpleIndexedTable;
+import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.MaxAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.SumAggregationFunction;
import org.apache.pinot.core.util.trace.TraceRunnable;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -63,7 +65,7 @@ public class BenchmarkIndexedTable {
private Random _random = new Random();
private DataSchema _dataSchema;
- private List<AggregationInfo> _aggregationInfos;
+ private AggregationFunction[] _aggregationFunctions;
private List<SelectionSort> _orderBy;
private List<String> _d1;
@@ -71,7 +73,6 @@ public class BenchmarkIndexedTable {
private ExecutorService _executorService;
-
@Setup
public void setup() {
// create data
@@ -90,26 +91,15 @@ public class BenchmarkIndexedTable {
}
_dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)"},
- new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT,
- DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
-
- AggregationInfo agg1 = new AggregationInfo();
- List<String> args1 = new ArrayList<>();
- args1.add("m1");
- agg1.setExpressions(args1);
- agg1.setAggregationType("sum");
-
- AggregationInfo agg2 = new AggregationInfo();
- List<String> args2 = new ArrayList<>();
- args2.add("m2");
- agg2.setExpressions(args2);
- agg2.setAggregationType("max");
- _aggregationInfos = Lists.newArrayList(agg1, agg2);
-
- SelectionSort orderBy = new SelectionSort();
- orderBy.setColumn("sum(m1)");
- orderBy.setIsAsc(true);
- _orderBy = Lists.newArrayList(orderBy);
+ new DataSchema.ColumnDataType[]{DataSchema.ColumnDataType.STRING, DataSchema.ColumnDataType.INT, DataSchema.ColumnDataType.DOUBLE, DataSchema.ColumnDataType.DOUBLE});
+
+ _aggregationFunctions =
+ new AggregationFunction[]{new SumAggregationFunction("m1"), new MaxAggregationFunction("m2")};
+
+ SelectionSort selectionSort = new SelectionSort();
+ selectionSort.setColumn("sum(m1)");
+ selectionSort.setIsAsc(true);
+ _orderBy = Collections.singletonList(selectionSort);
_executorService = Executors.newFixedThreadPool(10);
}
@@ -129,13 +119,14 @@ public class BenchmarkIndexedTable {
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
- public void concurrentIndexedTable() throws InterruptedException, ExecutionException, TimeoutException {
+ public void concurrentIndexedTable()
+ throws InterruptedException, ExecutionException, TimeoutException {
int numSegments = 10;
// make 1 concurrent table
IndexedTable concurrentIndexedTable =
- new ConcurrentIndexedTable(_dataSchema, _aggregationInfos, _orderBy, CAPACITY);
+ new ConcurrentIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, CAPACITY);
// 10 parallel threads putting 10k records into the table
@@ -172,11 +163,11 @@ public class BenchmarkIndexedTable {
}
}
-
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
- public void simpleIndexedTable() throws InterruptedException, TimeoutException, ExecutionException {
+ public void simpleIndexedTable()
+ throws InterruptedException, TimeoutException, ExecutionException {
int numSegments = 10;
@@ -186,7 +177,7 @@ public class BenchmarkIndexedTable {
for (int i = 0; i < numSegments; i++) {
// make 10 indexed tables
- IndexedTable simpleIndexedTable = new SimpleIndexedTable(_dataSchema, _aggregationInfos, _orderBy, CAPACITY);
+ IndexedTable simpleIndexedTable = new SimpleIndexedTable(_dataSchema, _aggregationFunctions, _orderBy, CAPACITY);
simpleIndexedTables.add(simpleIndexedTable);
// put 10k records in each indexed table, in parallel
@@ -217,13 +208,11 @@ public class BenchmarkIndexedTable {
mergedTable.finish(false);
}
- public static void main(String[] args) throws Exception {
- ChainedOptionsBuilder opt = new OptionsBuilder().include(BenchmarkIndexedTable.class.getSimpleName())
- .warmupTime(TimeValue.seconds(10))
- .warmupIterations(1)
- .measurementTime(TimeValue.seconds(30))
- .measurementIterations(3)
- .forks(1);
+ public static void main(String[] args)
+ throws Exception {
+ ChainedOptionsBuilder opt =
+ new OptionsBuilder().include(BenchmarkIndexedTable.class.getSimpleName()).warmupTime(TimeValue.seconds(10))
+ .warmupIterations(1).measurementTime(TimeValue.seconds(30)).measurementIterations(3).forks(1);
new Runner(opt.build()).run();
}
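In `setup()`, the commit replaces Guava's `Lists.newArrayList` with `java.util.Collections.singletonList` when building the one-element order-by list, which drops the Guava import. There is one behavioral difference: `singletonList` returns an immutable, fixed-size list. The swap is therefore only safe if nothing downstream adds to or removes from `_orderBy`. The sketch below demonstrates that difference. `SortSpec` is a hypothetical stand-in for `SelectionSort`, not the real Thrift class.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class SingletonListSketch {
  // Hypothetical stand-in for org.apache.pinot.common.request.SelectionSort.
  public static final class SortSpec {
    final String column;
    final boolean isAsc;

    SortSpec(String column, boolean isAsc) {
      this.column = column;
      this.isAsc = isAsc;
    }
  }

  // One-element order-by list, built the way the new setup() builds _orderBy.
  public static List<SortSpec> orderBy() {
    return Collections.singletonList(new SortSpec("sum(m1)", true));
  }

  // Returns true if the list rejects structural modification.
  public static boolean isImmutable(List<SortSpec> list) {
    try {
      list.add(new SortSpec("max(m2)", false));
      return false;
    } catch (UnsupportedOperationException e) {
      return true;
    }
  }

  public static void main(String[] args) {
    List<SortSpec> fixed = orderBy();
    List<SortSpec> mutable = new ArrayList<>(fixed); // roughly what Lists.newArrayList produced
    // prints "1 true false"
    System.out.println(fixed.size() + " " + isImmutable(fixed) + " " + isImmutable(mutable));
  }
}
```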