You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by si...@apache.org on 2023/02/02 18:39:05 UTC
[pinot] branch master updated: [multistage][bugfix] fix non-grouping agg result issue (#10222)
This is an automated email from the ASF dual-hosted git repository.
siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new c70fcaa282 [multistage][bugfix] fix non-grouping agg result issue (#10222)
c70fcaa282 is described below
commit c70fcaa28299622e7730e84e7991eebd5654c79a
Author: Rong Rong <ro...@apache.org>
AuthorDate: Thu Feb 2 10:38:59 2023 -0800
[multistage][bugfix] fix non-grouping agg result issue (#10222)
* adding test queries first
* add non-group-by agg default value
---------
Co-authored-by: Rong Rong <ro...@startree.ai>
---
.../query/runtime/operator/AggregateOperator.java | 38 +++++++++++++++-------
.../runtime/operator/AggregateOperatorTest.java | 4 +--
.../src/test/resources/queries/Aggregates.json | 28 ++++++++++++++++
3 files changed, 57 insertions(+), 13 deletions(-)
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 61e7f3dec7..90e00a943a 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -168,13 +169,28 @@ public class AggregateOperator extends MultiStageOperator {
}
_hasReturnedAggregateBlock = true;
if (rows.size() == 0) {
- return TransferableBlockUtils.getEndOfStreamTransferableBlock();
+ if (_groupSet.size() == 0) {
+ return constructEmptyAggResultBlock();
+ } else {
+ return TransferableBlockUtils.getEndOfStreamTransferableBlock();
+ }
} else {
_operatorStats.recordOutput(1, rows.size());
return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
}
}
+ /**
+ * @return an empty agg result block for non-group-by aggregation.
+ */
+ private TransferableBlock constructEmptyAggResultBlock() {
+ Object[] row = new Object[_aggCalls.size()];
+ for (int i = 0; i < _accumulators.length; i++) {
+ row[i] = _accumulators[i]._merger.initialize(null, _accumulators[i]._dataType);
+ }
+ return new TransferableBlock(Collections.singletonList(row), _resultSchema, DataBlock.Type.ROW);
+ }
+
/**
* @return whether or not the operator is ready to move on (EOS or ERROR)
*/
@@ -258,7 +274,7 @@ public class AggregateOperator extends MultiStageOperator {
}
@Override
- public Object initialize(Object other) {
+ public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
PinotFourthMoment moment = new PinotFourthMoment();
moment.increment(((Number) other).doubleValue());
return moment;
@@ -285,7 +301,7 @@ public class AggregateOperator extends MultiStageOperator {
}
@Override
- public Object initialize(Object other) {
+ public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
ObjectOpenHashSet<Object> set = new ObjectOpenHashSet<>();
set.add(other);
return set;
@@ -307,12 +323,12 @@ public class AggregateOperator extends MultiStageOperator {
/**
* Initializes the merger based on the first input
*/
- default Object initialize(Object other) {
- return other;
+ default Object initialize(Object other, DataSchema.ColumnDataType dataType) {
+ return other == null ? dataType.getNullPlaceholder() : other;
}
/**
- * Merges the existing aggregate (the result of {@link #initialize(Object)}) with
+ * Merges the existing aggregate (the result of {@link #initialize(Object, DataSchema.ColumnDataType)}) with
* the new value coming in (which may be an aggregate in and of itself).
*/
Object merge(Object agg, Object value);
@@ -353,22 +369,22 @@ public class AggregateOperator extends MultiStageOperator {
final Object _literal;
final Map<Key, Object> _results = new HashMap<>();
final Merger _merger;
+ final DataSchema.ColumnDataType _dataType;
Accumulator(RexExpression.FunctionCall aggCall, Map<String, Function<DataSchema.ColumnDataType, Merger>> merger,
String functionName, DataSchema inputSchema) {
// agg function operand should either be a InputRef or a Literal
- DataSchema.ColumnDataType dataType;
RexExpression rexExpression = toAggregationFunctionOperand(aggCall);
if (rexExpression instanceof RexExpression.InputRef) {
_inputRef = ((RexExpression.InputRef) rexExpression).getIndex();
_literal = null;
- dataType = inputSchema.getColumnDataType(_inputRef);
+ _dataType = inputSchema.getColumnDataType(_inputRef);
} else {
_inputRef = -1;
_literal = ((RexExpression.Literal) rexExpression).getValue();
- dataType = DataSchema.ColumnDataType.fromDataType(rexExpression.getDataType(), false);
+ _dataType = DataSchema.ColumnDataType.fromDataType(rexExpression.getDataType(), true);
}
- _merger = merger.get(functionName).apply(dataType);
+ _merger = merger.get(functionName).apply(_dataType);
}
void accumulate(Key key, Object[] row) {
@@ -377,7 +393,7 @@ public class AggregateOperator extends MultiStageOperator {
Object value = _inputRef == -1 ? _literal : row[_inputRef];
if (currentRes == null) {
- _results.put(key, _merger.initialize(value));
+ _results.put(key, _merger.initialize(value, _dataType));
} else {
Object mergedResult = _merger.merge(currentRes, value);
_results.put(key, mergedResult);
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 6f2212e6b6..0bda46470e 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -190,7 +190,7 @@ public class AggregateOperatorTest {
AggregateOperator.Merger merger = Mockito.mock(AggregateOperator.Merger.class);
Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
- Mockito.when(merger.initialize(Mockito.any())).thenReturn(1d);
+ Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
AggregateOperator operator =
new AggregateOperator(_input, outSchema, calls, group, inSchema, ImmutableMap.of("SUM", cdt -> merger), 1, 2);
@@ -201,7 +201,7 @@ public class AggregateOperatorTest {
// Then:
// should call merger twice, one from second row in first block and two from the first row
// in second block
- Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any());
+ Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 12d},
"Expected two columns (group by key, agg value)");
diff --git a/pinot-query-runtime/src/test/resources/queries/Aggregates.json b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
index e4243d2fc0..6ce21afbec 100644
--- a/pinot-query-runtime/src/test/resources/queries/Aggregates.json
+++ b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
@@ -138,6 +138,34 @@
"comment": "issue with converting data types: Unexpected RelDataTypeField: ANY for column: EXPR$0",
"description": "sum with inner function",
"sql": "SELECT sum(pow(int_col, 2)) FROM {tbl}"
+ },
+ {
+ "ignored": true,
+ "comment": "sum empty returns [0] instead of [null] at the moment",
+ "description": "sum empty input after filter",
+ "sql": "SELECT sum(int_col) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+ },
+ {
+ "description": "count empty input after filter",
+ "sql": "SELECT count(*) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+ },
+ {
+ "description": "count empty input after filter",
+ "sql": "SELECT count(int_col) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+ },
+ {
+ "ignored": true,
+ "comment": "sum empty returns [0] instead of [null] at the moment",
+ "description": "sum empty input after filter with subquery",
+ "sql": "SELECT sum(int_col) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
+ },
+ {
+ "description": "count empty input after filter with sub-query",
+ "sql": "SELECT count(*) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
+ },
+ {
+ "description": "count empty input after filter with sub-query",
+ "sql": "SELECT count(int_col) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
}
]
},
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org