You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ro...@apache.org on 2023/07/28 20:36:05 UTC
[pinot] branch master updated: [multistage][agg] support agg with filter (#11144)
This is an automated email from the ASF dual-hosted git repository.
rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 05989559e0 [multistage][agg] support agg with filter (#11144)
05989559e0 is described below
commit 05989559e06d448d23d55147672345ede624381d
Author: Rong Rong <ro...@apache.org>
AuthorDate: Fri Jul 28 13:35:59 2023 -0700
[multistage][agg] support agg with filter (#11144)
* [init][agg] support agg filter where clause
* Limitations still apply for nullable vs. non-nullable results and for merging the agg filter with the select filter; these will be addressed in follow-ups.
---------
Co-authored-by: Rong Rong <ro...@startree.ai>
---
.../pinot/common/datablock/DataBlockUtils.java | 311 +++++++++++++++++++++
.../PinotAggregateExchangeNodeInsertRule.java | 2 +-
.../query/parser/CalciteRexExpressionParser.java | 30 +-
.../query/planner/plannode/AggregateNode.java | 7 +
.../query/runtime/operator/AggregateOperator.java | 75 +++--
.../operator/MultistageAggregationExecutor.java | 27 +-
.../operator/MultistageGroupByExecutor.java | 83 +++++-
.../runtime/operator/block/DataBlockValSet.java | 15 +-
.../operator/block/FilteredDataBlockValSet.java | 33 +--
.../query/runtime/plan/PhysicalPlanVisitor.java | 2 +-
.../plan/server/ServerPlanRequestVisitor.java | 6 +-
.../runtime/operator/AggregateOperatorTest.java | 16 +-
.../test/resources/queries/FilterAggregates.json | 166 +++++++++++
13 files changed, 679 insertions(+), 94 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
index 99c3e4df46..b5d8321280 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
@@ -23,6 +23,7 @@ import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Timestamp;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -560,4 +561,314 @@ public final class DataBlockUtils {
return rows;
}
+
+  /**
+   * Extracts int values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * INT/BOOLEAN are read as-is; LONG/FLOAT/DOUBLE are narrowed by cast, BIG_DECIMAL via intValue(). Null cells in
+   * kept rows stay 0. Avoids extractRowFromDatablock's boxing. Only works on ROW format (TODO: COLUMNAR).
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   * @throws IllegalStateException if the column type is not INT, BOOLEAN, LONG, FLOAT, DOUBLE or BIG_DECIMAL
+   */
+  public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    int[] rows = new int[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = (int) dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = (int) dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (int) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).intValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts long values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * INT/BOOLEAN/LONG are read as long; FLOAT/DOUBLE are truncated by cast, BIG_DECIMAL via longValue(). Null cells
+   * in kept rows stay 0. Avoids extractRowFromDatablock's boxing. Only works on ROW format (TODO: COLUMNAR).
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   * @throws IllegalStateException if the column type is not INT, BOOLEAN, LONG, FLOAT, DOUBLE or BIG_DECIMAL
+   */
+  public static long[] extractLongValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    long[] rows = new long[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = (long) dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (long) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).longValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts float values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * FLOAT is read as-is; INT/BOOLEAN/LONG/DOUBLE are converted (may lose precision), BIG_DECIMAL via floatValue().
+   * Null cells in kept rows stay 0.0f. Avoids extractRowFromDatablock's boxing. ROW format only (TODO: COLUMNAR).
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   * @throws IllegalStateException if the column type is not INT, BOOLEAN, LONG, FLOAT, DOUBLE or BIG_DECIMAL
+   */
+  public static float[] extractFloatValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    float[] rows = new float[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (float) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).floatValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts double values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * DOUBLE is read as-is; INT/BOOLEAN/LONG/FLOAT are widened, BIG_DECIMAL via doubleValue().
+   * Null cells in kept rows stay 0.0. Avoids extractRowFromDatablock's boxing. ROW format only (TODO: COLUMNAR).
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   * @throws IllegalStateException if the column type is not INT, BOOLEAN, LONG, FLOAT, DOUBLE or BIG_DECIMAL
+   */
+  public static double[] extractDoubleValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    double[] rows = new double[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).doubleValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts BigDecimal values of {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * BIG_DECIMAL cells are returned as stored (a round trip through double would lose precision); other numeric
+   * types go through BigDecimal.valueOf. Null cells in kept rows stay null. ROW format only (TODO: COLUMNAR).
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   * @throws IllegalStateException if the column type is not INT, BOOLEAN, LONG, FLOAT, DOUBLE or BIG_DECIMAL
+   */
+  public static BigDecimal[] extractBigDecimalValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    BigDecimal[] rows = new BigDecimal[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default (null)
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getInt(inRowId, columnIndex));
+            break;
+          case LONG:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getLong(inRowId, columnIndex));
+            break;
+          case FLOAT:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getFloat(inRowId, columnIndex));
+            break;
+          case DOUBLE:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getDouble(inRowId, columnIndex));
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex); // exact, no double round trip
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts String values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * Every kept cell is read with getString regardless of the declared column type; null cells in kept rows stay
+   * null. Prefer this over extractRowFromDatablock when the target type is known, to avoid boxing.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one value per kept row, in row order
+   */
+  public static String[] extractStringValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    String[] rows = new String[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default (null)
+          continue;
+        }
+        rows[outRowId++] = dataBlock.getString(inRowId, columnIndex);
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
+
+  /**
+   * Extracts byte[] values of column {@code columnIndex} from the rows whose {@code filterArgIdx} column equals 1.
+   * Every kept cell is read with getBytes(...).getBytes() regardless of the declared column type; null cells in
+   * kept rows stay null. Prefer this over extractRowFromDatablock when the target type is known, to avoid boxing.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @param filterArgIdx index of the filter column (read via getInt); a row is kept only when that value is 1
+   * @return one byte[] per kept row, in row order
+   */
+  public static byte[][] extractBytesValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    byte[][] rows = new byte[numRows][];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) { // row passes the filter
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++; // null cell: keep its slot, left at the array default (null)
+          continue;
+        }
+        rows[outRowId++] = dataBlock.getBytes(inRowId, columnIndex).getBytes();
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId); // trim to the number of kept rows
+  }
}
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index 60efa69fde..df904123d2 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -324,7 +324,7 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
orgAggCall.isApproximate(),
orgAggCall.ignoreNulls(),
argList,
- orgAggCall.filterArg,
+ aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg,
orgAggCall.distinctKeys,
orgAggCall.collation,
numberGroups,
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index 632471f6b0..4c31fc86f4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -19,6 +19,7 @@
package org.apache.pinot.query.parser;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -67,19 +68,30 @@ public class CalciteRexExpressionParser {
// Relational conversion Utils
// --------------------------------------------------------------------------
- public static List<Expression> overwriteSelectList(List<RexExpression> rexNodeList, PinotQuery pinotQuery) {
- return addSelectList(new ArrayList<>(), rexNodeList, pinotQuery);
- }
-
- public static List<Expression> addSelectList(List<Expression> existingList, List<RexExpression> rexNodeList,
- PinotQuery pinotQuery) {
- List<Expression> selectExpr = new ArrayList<>(existingList);
-
- final Iterator<RexExpression> iterator = rexNodeList.iterator();
+  public static List<Expression> convertProjectList(List<RexExpression> projectList, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>(); // one Expression per project item, in input order
+    final Iterator<RexExpression> iterator = projectList.iterator();
while (iterator.hasNext()) {
final RexExpression next = iterator.next();
selectExpr.add(toExpression(next, pinotQuery));
}
+    return selectExpr;
+  }
+
+  public static List<Expression> convertAggregateList(List<Expression> groupSetList, List<RexExpression> aggCallList,
+      List<Integer> filterArgIndices, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>(groupSetList);
+    // Group-set expressions first, then one per agg call; a null/empty filterArgIndices means no call is filtered.
+    for (int idx = 0; idx < aggCallList.size(); idx++) {
+      final RexExpression aggCall = aggCallList.get(idx);
+      int filterArgIdx = filterArgIndices == null || filterArgIndices.isEmpty() ? -1 : filterArgIndices.get(idx);
+      if (filterArgIdx == -1) { // -1: this agg call has no FILTER clause
+        selectExpr.add(toExpression(aggCall, pinotQuery));
+      } else { // wrap as FILTER(aggCall, <input ref of the filter column>)
+        selectExpr.add(toExpression(new RexExpression.FunctionCall(SqlKind.FILTER, aggCall.getDataType(), "FILTER",
+            Arrays.asList(aggCall, new RexExpression.InputRef(filterArgIdx))), pinotQuery));
+      }
+    }
return selectExpr;
}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
index c465fe93f6..5c8a0999c0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
@@ -36,6 +36,8 @@ public class AggregateNode extends AbstractPlanNode {
@ProtoProperties
private List<RexExpression> _aggCalls;
@ProtoProperties
+ private List<Integer> _filterArgIndices;
+ @ProtoProperties
private List<RexExpression> _groupSet;
@ProtoProperties
private AggType _aggType;
@@ -49,6 +51,7 @@ public class AggregateNode extends AbstractPlanNode {
super(planFragmentId, dataSchema);
Preconditions.checkState(areHintsValid(relHints), "invalid sql hint for agg node: {}", relHints);
_aggCalls = aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
+ _filterArgIndices = aggCalls.stream().map(c -> c.filterArg).collect(Collectors.toList());
_groupSet = groupSet;
_nodeHint = new NodeHint(relHints);
_aggType = AggType.valueOf(PinotHintStrategyTable.getHintOption(relHints, PinotHintOptions.INTERNAL_AGG_OPTIONS,
@@ -63,6 +66,10 @@ public class AggregateNode extends AbstractPlanNode {
return _aggCalls;
}
+  public List<Integer> getFilterArgIndices() {
+    return _filterArgIndices; // entry i is the filterArg of agg call i (-1 when it has no FILTER clause)
+  }
+
public List<RexExpression> getGroupSet() {
return _groupSet;
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 494ad1d737..22814b4571 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -35,7 +35,6 @@ import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
import org.apache.pinot.query.planner.logical.LiteralHintUtils;
@@ -44,7 +43,10 @@ import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
import org.apache.pinot.query.runtime.blocks.TransferableBlock;
import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
+import org.apache.pinot.query.runtime.operator.block.DataBlockValSet;
+import org.apache.pinot.query.runtime.operator.block.FilteredDataBlockValSet;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.data.FieldSpec;
/**
@@ -90,31 +92,28 @@ public class AggregateOperator extends MultiStageOperator {
private MultistageAggregationExecutor _aggregationExecutor;
private MultistageGroupByExecutor _groupByExecutor;
- // TODO: refactor Pinot Reducer code to support the intermediate stage agg operator.
- // aggCalls has to be a list of FunctionCall and cannot be null
- // groupSet has to be a list of InputRef and cannot be null
- // TODO: Add these two checks when we confirm we can handle error in upstream ctor call.
- public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator,
- DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet,
- AggType aggType) {
- this(context, inputOperator, resultSchema, inputSchema, aggCalls, groupSet, aggType,
- new AbstractPlanNode.NodeHint());
- }
-
@VisibleForTesting
- public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator,
- DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet,
- AggType aggType, AbstractPlanNode.NodeHint nodeHint) {
+ public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema resultSchema,
+ DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType aggType,
+ @Nullable List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint nodeHint) {
super(context);
_inputOperator = inputOperator;
_resultSchema = resultSchema;
_inputSchema = inputSchema;
_aggType = aggType;
+ // filter arg index array
+ int[] filterArgIndexArray;
+ if (filterArgIndices == null || filterArgIndices.size() == 0) {
+ filterArgIndexArray = null;
+ } else {
+ filterArgIndexArray = filterArgIndices.stream().mapToInt(Integer::intValue).toArray();
+ }
+ // filter operand and literal hints
if (nodeHint != null && nodeHint._hintOptions != null
&& nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS) != null) {
- _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap(nodeHint._hintOptions
- .get(PinotHintOptions.INTERNAL_AGG_OPTIONS)
- .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
+ _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap(
+ nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS)
+ .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
} else {
_aggCallSignatureMap = Collections.emptyMap();
}
@@ -126,6 +125,7 @@ public class AggregateOperator extends MultiStageOperator {
// Convert groupSet to ExpressionContext that our aggregation functions understand.
List<ExpressionContext> groupByExpr = getGroupSet(groupSet);
+
List<FunctionContext> functionContexts = getFunctionContexts(aggCalls);
AggregationFunction[] aggFunctions = new AggregationFunction[functionContexts.size()];
@@ -136,12 +136,14 @@ public class AggregateOperator extends MultiStageOperator {
// Initialize the appropriate executor.
if (!groupSet.isEmpty()) {
_isGroupByAggregation = true;
- _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap,
- _resultSchema);
+ _groupByExecutor =
+ new MultistageGroupByExecutor(groupByExpr, aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap,
+ _resultSchema);
} else {
_isGroupByAggregation = false;
- _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap,
- _resultSchema);
+ _aggregationExecutor =
+ new MultistageAggregationExecutor(aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap,
+ _resultSchema);
}
}
@@ -253,7 +255,7 @@ public class AggregateOperator extends MultiStageOperator {
// The literal value here does not matter. We create a dummy literal here just so that the count aggregation
// has some column to process.
if (aggArguments.isEmpty()) {
- aggArguments.add(ExpressionContext.forIdentifier("__PLACEHOLDER__"));
+ aggArguments.add(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "__PLACEHOLDER__"));
}
return new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, aggArguments);
@@ -320,8 +322,8 @@ public class AggregateOperator extends MultiStageOperator {
// TODO: If the previous block is not mailbox received, this method is not efficient. Then getDataBlock() will
// convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
// value primitive type.
- static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
- TransferableBlock block, DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap) {
+ static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction, TransferableBlock block,
+ DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap, int filterArgIdx) {
List<ExpressionContext> expressions = aggFunction.getInputExpressions();
int numExpressions = expressions.size();
if (numExpressions == 0) {
@@ -330,17 +332,34 @@ public class AggregateOperator extends MultiStageOperator {
Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
for (ExpressionContext expression : expressions) {
- if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER)
- && !"__PLACEHOLDER__".equals(expression.getIdentifier())) {
+ if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER) && !"__PLACEHOLDER__".equals(
+ expression.getIdentifier())) {
int index = colNameToIndexMap.get(expression.getIdentifier());
DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
- blockValSetMap.put(expression, new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
+ if (filterArgIdx == -1) {
+ blockValSetMap.put(expression, new DataBlockValSet(dataType, block.getDataBlock(), index));
+ } else {
+ blockValSetMap.put(expression,
+ new FilteredDataBlockValSet(dataType, block.getDataBlock(), index, filterArgIdx));
+ }
}
}
return blockValSetMap;
}
+  static int computeBlockNumRows(TransferableBlock block, int filterArgIdx) {
+    if (filterArgIdx == -1) { // -1: no FILTER clause, so every row of the block is aggregated
+      return block.getNumRows();
+    }
+    DataBlock dataBlock = block.getDataBlock(); // hoisted: loop-invariant, and may convert the block format
+    int rowCount = 0;
+    for (int rowId = 0, numRows = block.getNumRows(); rowId < numRows; rowId++) {
+      rowCount += dataBlock.getInt(rowId, filterArgIdx) == 1 ? 1 : 0; // filter column value 1 = row passes
+    }
+    return rowCount;
+  }
+
static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row,
Map<String, Integer> colNameToIndexMap) {
List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
index 19e4f66cc6..b352997ac1 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
@@ -42,13 +43,15 @@ public class MultistageAggregationExecutor {
private final DataSchema _resultSchema;
private final AggregationFunction[] _aggFunctions;
+ private final int[] _filterArgIndices;
// Result holders for each mode.
private final AggregationResultHolder[] _aggregateResultHolder;
private final Object[] _mergeResultHolder;
- public MultistageAggregationExecutor(AggregationFunction[] aggFunctions,
+ public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, @Nullable int[] filterArgIndices,
AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
+ _filterArgIndices = filterArgIndices;
_aggFunctions = aggFunctions;
_aggType = aggType;
_colNameToIndexMap = colNameToIndexMap;
@@ -116,11 +119,23 @@ public class MultistageAggregationExecutor {
}
private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
-    for (int i = 0; i < _aggFunctions.length; i++) {
-      AggregationFunction aggregationFunction = _aggFunctions[i];
-      Map<ExpressionContext, BlockValSet> blockValSetMap =
-          AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
-      aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+    if (_filterArgIndices == null) { // no filter indices supplied: every agg call sees all rows
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1);
+        aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+      }
+    } else {
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        int filterArgIdx = _filterArgIndices[i]; // -1 if this agg call has no FILTER clause
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap,
+                filterArgIdx);
+        int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx); // rows passing the filter
+        aggregationFunction.aggregate(numRows, _aggregateResultHolder[i], blockValSetMap);
+      }
}
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index 5eacba025b..e33cc491cf 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -19,9 +19,11 @@
package org.apache.pinot.query.runtime.operator;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import javax.annotation.Nullable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
@@ -47,6 +49,7 @@ public class MultistageGroupByExecutor {
private final List<ExpressionContext> _groupSet;
private final AggregationFunction[] _aggFunctions;
+ private final int[] _filterArgIndices;
// Group By Result holders for each mode
private final GroupByResultHolder[] _aggregateResultHolders;
@@ -57,11 +60,13 @@ public class MultistageGroupByExecutor {
private final Map<Key, Integer> _groupKeyToIdMap;
public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
- AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
+ @Nullable int[] filterArgIndices, AggType aggType, Map<String, Integer> colNameToIndexMap,
+ DataSchema resultSchema) {
_aggType = aggType;
_colNameToIndexMap = colNameToIndexMap;
_groupSet = groupByExpr;
_aggFunctions = aggFunctions;
+ _filterArgIndices = filterArgIndices;
_resultSchema = resultSchema;
_aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
@@ -70,9 +75,9 @@ public class MultistageGroupByExecutor {
_groupKeyToIdMap = new HashMap<>();
for (int i = 0; i < _aggFunctions.length; i++) {
- _aggregateResultHolders[i] =
- _aggFunctions[i].createGroupByResultHolder(InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
- InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
+ _aggregateResultHolders[i] = _aggFunctions[i].createGroupByResultHolder(
+ InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
+ InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
}
}
@@ -129,15 +134,29 @@ public class MultistageGroupByExecutor {
}
private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
- int[] intKeys = generateGroupByKeys(block.getContainer());
-
- for (int i = 0; i < _aggFunctions.length; i++) {
- AggregationFunction aggregationFunction = _aggFunctions[i];
- Map<ExpressionContext, BlockValSet> blockValSetMap =
- AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
- GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
- groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
- aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+ if (_filterArgIndices == null) {
+ int[] intKeys = generateGroupByKeys(block.getContainer());
+ for (int i = 0; i < _aggFunctions.length; i++) {
+ AggregationFunction aggregationFunction = _aggFunctions[i];
+ Map<ExpressionContext, BlockValSet> blockValSetMap =
+ AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1);
+ GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+ groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+ aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+ }
+ } else {
+ for (int i = 0; i < _aggFunctions.length; i++) {
+ AggregationFunction aggregationFunction = _aggFunctions[i];
+ int filterArgIdx = _filterArgIndices[i];
+ int[] intKeys = generateGroupByKeys(block.getContainer(), filterArgIdx);
+ Map<ExpressionContext, BlockValSet> blockValSetMap =
+ AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap,
+ filterArgIdx);
+ int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx);
+ GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+ groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+ aggregationFunction.aggregateGroupBySV(numRows, intKeys, groupByResultHolder, blockValSetMap);
+ }
}
}
@@ -191,4 +210,42 @@ public class MultistageGroupByExecutor {
}
return rowIntKeys;
}
+
+  /**
+   * Creates the 0-index based int group key (usable by v1 GroupByAggregationResultHolders) for each retained row,
+   * assigning the next id in {@code _groupKeyToIdMap} to keys not seen before. All rows are retained when
+   * {@code filterArgIndex == -1}; otherwise only rows whose value at {@code filterArgIndex} is {@code TRUE}.
+   * Returns one key per retained row, in input order (shorter than {@code rows} when rows are filtered out).
+   */
+  private int[] generateGroupByKeys(List<Object[]> rows, int filterArgIndex) {
+    int numRows = rows.size();
+    int[] rowIntKeys = new int[numRows];
+    int numKeys = _groupSet.size();
+    if (filterArgIndex == -1) { // -1: no filter, every row gets a key
+      for (int rowId = 0; rowId < numRows; rowId++) {
+        Object[] row = rows.get(rowId);
+        Object[] keyValues = new Object[numKeys];
+        for (int j = 0; j < numKeys; j++) {
+          keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())];
+        }
+        Key rowKey = new Key(keyValues);
+        rowIntKeys[rowId] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size());
+      }
+      return rowIntKeys;
+    } else {
+      int outRowId = 0;
+      for (int inRowId = 0; inRowId < numRows; inRowId++) {
+        Object[] row = rows.get(inRowId);
+        if ((Boolean) row[filterArgIndex]) { // NOTE(review): NPE if value is null; assumes non-null BOOLEAN
+          Object[] keyValues = new Object[numKeys];
+          for (int j = 0; j < numKeys; j++) {
+            keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())];
+          }
+          Key rowKey = new Key(keyValues);
+          rowIntKeys[outRowId++] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size());
+        }
+      }
+      return Arrays.copyOfRange(rowIntKeys, 0, outRowId); // trim to the number of retained rows
+    }
+  }
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
similarity index 90%
copy from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
index 7ddaaf04c9..e1bbc077ef 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
@@ -16,13 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.core.common;
+package org.apache.pinot.query.runtime.operator.block;
import java.math.BigDecimal;
import javax.annotation.Nullable;
import org.apache.pinot.common.datablock.DataBlock;
import org.apache.pinot.common.datablock.DataBlockUtils;
import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.RoaringBitmap;
@@ -34,13 +35,13 @@ import org.roaringbitmap.RoaringBitmap;
* aggregations using v1 aggregation functions.
* TODO: Support MV
*/
-public class IntermediateStageBlockValSet implements BlockValSet {
- private final FieldSpec.DataType _dataType;
- private final DataBlock _dataBlock;
- private final int _index;
- private final RoaringBitmap _nullBitMap;
+public class DataBlockValSet implements BlockValSet {
+ protected final FieldSpec.DataType _dataType;
+ protected final DataBlock _dataBlock;
+ protected final int _index;
+ protected final RoaringBitmap _nullBitMap;
- public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
+ public DataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
_dataType = columnDataType.toDataType();
_dataBlock = dataBlock;
_index = colIndex;
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
similarity index 86%
rename from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
index 7ddaaf04c9..e4231fbd71 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.pinot.core.common;
+package org.apache.pinot.query.runtime.operator.block;
import java.math.BigDecimal;
import javax.annotation.Nullable;
@@ -27,6 +27,7 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
import org.apache.pinot.spi.data.FieldSpec;
import org.roaringbitmap.RoaringBitmap;
+
/**
* In the multistage engine, the leaf stage servers process the data in columnar fashion. By the time the
* intermediate stage receives the projected column, they are converted to a row based format. This class provides
@@ -34,17 +35,13 @@ import org.roaringbitmap.RoaringBitmap;
* aggregations using v1 aggregation functions.
* TODO: Support MV
*/
-public class IntermediateStageBlockValSet implements BlockValSet {
- private final FieldSpec.DataType _dataType;
- private final DataBlock _dataBlock;
- private final int _index;
- private final RoaringBitmap _nullBitMap;
+public class FilteredDataBlockValSet extends DataBlockValSet {
+ private final int _filterIdx;
- public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
- _dataType = columnDataType.toDataType();
- _dataBlock = dataBlock;
- _index = colIndex;
- _nullBitMap = dataBlock.getNullRowIds(colIndex);
+ public FilteredDataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex,
+ int filterIdx) {
+ super(columnDataType, dataBlock, colIndex);
+ _filterIdx = filterIdx;
}
/**
@@ -80,37 +77,37 @@ public class IntermediateStageBlockValSet implements BlockValSet {
@Override
public int[] getIntValuesSV() {
- return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public long[] getLongValuesSV() {
- return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public float[] getFloatValuesSV() {
- return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public double[] getDoubleValuesSV() {
- return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public BigDecimal[] getBigDecimalValuesSV() {
- return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public String[] getStringValuesSV() {
- return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
public byte[][] getBytesValuesSV() {
- return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index);
+ return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index, _filterIdx);
}
@Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
index 79f2275274..2340545437 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
@@ -101,7 +101,7 @@ public class PhysicalPlanVisitor implements PlanNodeVisitor<MultiStageOperator,
DataSchema resultSchema = node.getDataSchema();
return new AggregateOperator(context.getOpChainExecutionContext(), nextOperator, resultSchema, inputSchema,
- node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getNodeHint());
+ node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getFilterArgIndices(), node.getNodeHint());
}
@Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
index 8fa6df2756..5e3873fbcd 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
@@ -71,8 +71,8 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla
.setGroupByList(CalciteRexExpressionParser.convertGroupByList(node.getGroupSet(), context.getPinotQuery()));
// set agg list
context.getPinotQuery().setSelectList(
- CalciteRexExpressionParser.addSelectList(context.getPinotQuery().getGroupByList(), node.getAggCalls(),
- context.getPinotQuery()));
+ CalciteRexExpressionParser.convertAggregateList(context.getPinotQuery().getGroupByList(), node.getAggCalls(),
+ node.getFilterArgIndices(), context.getPinotQuery()));
return null;
}
@@ -149,7 +149,7 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla
public Void visitProject(ProjectNode node, ServerPlanRequestContext context) {
visitChildren(node, context);
context.getPinotQuery()
- .setSelectList(CalciteRexExpressionParser.overwriteSelectList(node.getProjects(), context.getPinotQuery()));
+ .setSelectList(CalciteRexExpressionParser.convertProjectList(node.getProjects(), context.getPinotQuery()));
return null;
}
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 8048bad5f4..d32d38b2b0 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -79,7 +79,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.INTERMEDIATE, null, null);
// When:
TransferableBlock block1 = operator.nextBlock(); // build
@@ -101,7 +101,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.LEAF);
+ AggType.LEAF, null, null);
// When:
TransferableBlock block = operator.nextBlock();
@@ -125,7 +125,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.LEAF);
+ AggType.LEAF, null, null);
// When:
TransferableBlock block1 = operator.nextBlock(); // build when reading NoOp block
@@ -150,7 +150,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.INTERMEDIATE, null, null);
// When:
TransferableBlock block1 = operator.nextBlock();
@@ -177,7 +177,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, LONG});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.FINAL);
+ AggType.FINAL, null, null);
// When:
TransferableBlock block1 = operator.nextBlock();
@@ -201,7 +201,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{STRING, DOUBLE});
AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
outSchema, inSchema, Collections.singletonList(agg),
- Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF);
+ Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF, null, null);
TransferableBlock result = sum0GroupBy1.getNextBlock();
while (result.isNoOpBlock()) {
result = sum0GroupBy1.getNextBlock();
@@ -229,7 +229,7 @@ public class AggregateOperatorTest {
// When:
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.INTERMEDIATE, null, null);
}
@Test
@@ -248,7 +248,7 @@ public class AggregateOperatorTest {
DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
- AggType.INTERMEDIATE);
+ AggType.INTERMEDIATE, null, null);
// When:
TransferableBlock block = operator.nextBlock();
diff --git a/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json
new file mode 100644
index 0000000000..af308031ab
--- /dev/null
+++ b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json
@@ -0,0 +1,166 @@
+{
+ "general_aggregate_with_filter_where": {
+ "tables": {
+ "tbl": {
+ "schema": [
+ {"name": "int_col", "type": "INT"},
+ {"name": "double_col", "type": "DOUBLE"},
+ {"name": "string_col", "type": "STRING"},
+ {"name": "bool_col", "type": "BOOLEAN"}
+ ],
+ "inputs": [
+ [2, 300, "a", true],
+ [2, 400, "a", true],
+ [3, 100, "b", false],
+ [100, 1, "b", false],
+ [101, 1.01, "c", false],
+ [150, 1.5, "c", false],
+ [175, 1.75, "c", true]
+ ]
+ }
+ },
+ "queries": [
+ {
+ "ignored": true,
+      "comments": "FILTER WHERE clause with IN is hard-wired to translate into a subquery, which should not happen in this case.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl}"
+ },
+ {
+ "ignored": true,
+      "comments": "IS NULL and IS NOT NULL are not yet supported in filter conversion.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl}"
+ },
+ {
+ "ignored": true,
+      "comments": "agg with filter and group-by causes a conversion issue in v1 if the group-by field is not in the select list",
+ "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col"
+ },
+ {
+ "ignored": true,
+ "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} GROUP BY int_col, string_col"
+ },
+ {
+ "ignored": true,
+      "comments": "mixed/conflicting filters that require merging in v1 are not supported",
+ "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col = 'b' GROUP BY double_col, bool_col"
+ },
+ {
+ "ignored": true,
+      "comments": "FILTER WHERE clause might omit a group key entirely if no rows are selected for it; this is non-standard SQL behavior but matches v1 behavior",
+ "sql": "SELECT int_col, count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col"
+ },
+ { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl}" },
+ { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl}" },
+ { "sql": "SELECT min(int_col) FILTER (WHERE bool_col IS TRUE), max(int_col) FILTER (WHERE bool_col AND int_col < 10), avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), sum(int_col), count(int_col), count(distinct(int_col)), count(*) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl}" },
+ { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col='b'" },
+ { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} WHERE string_col='b'" },
+ { "sql": "SELECT int_col, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 0), 0), count(*) FROM {tbl} GROUP BY int_col" },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT int_col, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), 0) FROM {tbl} GROUP BY int_col, string_col"
+ },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT double_col, COALESCE(min(int_col) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col) FILTER (WHERE bool_col AND int_col < 10), 0), COALESCE(avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), 0), sum(int_col), count(int_col), count(distinct(int_col)), count(string_col) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl} GROUP BY double_col"
+ },
+ { "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10), count(int_col) FROM {tbl} WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(string_col) FROM {tbl} WHERE string_col='b' GROUP BY bool_col"
+ }
+ ]
+ },
+ "general_aggregate_with_filter_where_after_join": {
+ "tables": {
+ "tbl1": {
+ "schema": [
+ {"name": "int_col", "type": "INT"},
+ {"name": "double_col", "type": "DOUBLE"},
+ {"name": "string_col", "type": "STRING"},
+ {"name": "bool_col", "type": "BOOLEAN"}
+ ],
+ "inputs": [
+ [2, 300, "a", true],
+ [2, 400, "a", true],
+ [3, 100, "b", false],
+ [100, 1, "b", false],
+ [101, 1.01, "c", false],
+ [150, 1.5, "c", false],
+ [175, 1.75, "c", true]
+ ]
+ },
+ "tbl2": {
+ "schema":[
+ {"name": "int_col2", "type": "INT"},
+ {"name": "string_col2", "type": "STRING"},
+ {"name": "double_col2", "type": "DOUBLE"}
+ ],
+ "inputs": [
+ [1, "apple", 1000.0],
+ [2, "a", 1.323],
+ [3, "b", 1212.12],
+ [3, "c", 341],
+ [4, "orange", 1212.121]
+ ]
+ }
+ },
+ "queries": [
+ {
+ "ignored": true,
+      "comments": "FILTER WHERE clause with IN is hard-wired to translate into a subquery, which should not happen in this case.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2"
+ },
+ {
+ "ignored": true,
+      "comments": "IS NULL and IS NOT NULL are not yet supported in filter conversion.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2"
+ },
+ {
+ "ignored": true,
+      "comments": "agg with filter and group-by causes a conversion issue in v1 if the group-by field is not in the select list",
+ "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2"
+ },
+ {
+ "ignored": true,
+ "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.",
+ "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col"
+ },
+ {
+ "ignored": true,
+      "comments": "mixed/conflicting filters that require merging in v1 are not supported",
+ "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col = 'b' GROUP BY double_col, bool_col"
+ },
+ {
+ "ignored": true,
+      "comments": "FILTER WHERE clause might omit a group key entirely if no rows are selected for it; this is non-standard SQL behavior but matches v1 behavior",
+ "sql": "SELECT int_col2, count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2"
+ },
+ { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+ { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+ { "sql": "SELECT min(int_col2) FILTER (WHERE bool_col IS TRUE), max(int_col2) FILTER (WHERE bool_col AND int_col2 < 10), avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(*) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+ { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" },
+ { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" },
+ { "sql": "SELECT int_col2, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 0), 0), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2" },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT int_col2, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col"
+ },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT double_col, COALESCE(min(int_col2) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col2) FILTER (WHERE bool_col AND int_col2 < 10), 0), COALESCE(avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(string_col) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY double_col"
+ },
+ { "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10), count(int_col2) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" },
+ {
+ "ignored": true,
+      "comments": "Calcite limitation: SQL type inference and relational type inference disagree on whether a filterArg exists, causing a nullability mismatch",
+ "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(string_col) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b' GROUP BY bool_col"
+ }
+ ]
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org