Posted to commits@pinot.apache.org by si...@apache.org on 2023/02/02 18:39:05 UTC

[pinot] branch master updated: [multistage][bugfix] fix non-grouping agg result issue (#10222)

This is an automated email from the ASF dual-hosted git repository.

siddteotia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new c70fcaa282 [multistage][bugfix] fix non-grouping agg result issue (#10222)
c70fcaa282 is described below

commit c70fcaa28299622e7730e84e7991eebd5654c79a
Author: Rong Rong <ro...@apache.org>
AuthorDate: Thu Feb 2 10:38:59 2023 -0800

    [multistage][bugfix] fix non-grouping agg result issue (#10222)
    
    * adding test queries first
    
    * add non-group-by agg default value
    
    ---------
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../query/runtime/operator/AggregateOperator.java  | 38 +++++++++++++++-------
 .../runtime/operator/AggregateOperatorTest.java    |  4 +--
 .../src/test/resources/queries/Aggregates.json     | 28 ++++++++++++++++
 3 files changed, 57 insertions(+), 13 deletions(-)
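
In short: before this change, a non-group-by aggregation (e.g. SELECT COUNT(*) FROM tbl WHERE <filter>) over an empty input returned an empty end-of-stream block, i.e. no result row at all. With this patch the AggregateOperator emits a single row seeded with each aggregation's default value, while group-by aggregations over empty input continue to return zero rows. The self-contained sketch below illustrates that contract only; the class and method names are illustrative stand-ins, not the Pinot classes touched in the patch.

    // Illustrative sketch only (hypothetical names, not the Pinot classes):
    // over an empty input, a non-group-by aggregation still emits exactly one
    // row of per-aggregation default values; a group-by aggregation emits none.
    import java.util.Collections;
    import java.util.List;
    import java.util.function.Supplier;

    final class EmptyInputAggSketch {
      static List<Object[]> aggregateEmptyInput(int groupKeyCount, List<Supplier<Object>> defaults) {
        if (groupKeyCount > 0) {
          return Collections.emptyList();      // GROUP BY over empty input: zero rows
        }
        Object[] row = new Object[defaults.size()];
        for (int i = 0; i < defaults.size(); i++) {
          // Plays the role of Merger.initialize(null, dataType) in the patch:
          // with no input value, each aggregation supplies its default.
          row[i] = defaults.get(i).get();
        }
        return Collections.singletonList(row); // exactly one default row
      }

      public static void main(String[] args) {
        // COUNT(*) over an empty, non-grouped input -> a single row containing 0
        List<Supplier<Object>> countDefault = List.of(() -> 0L);
        System.out.println(aggregateEmptyInput(0, countDefault).get(0)[0]);
      }
    }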

diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 61e7f3dec7..90e00a943a 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -168,13 +169,28 @@ public class AggregateOperator extends MultiStageOperator {
     }
     _hasReturnedAggregateBlock = true;
     if (rows.size() == 0) {
-      return TransferableBlockUtils.getEndOfStreamTransferableBlock();
+      if (_groupSet.size() == 0) {
+        return constructEmptyAggResultBlock();
+      } else {
+        return TransferableBlockUtils.getEndOfStreamTransferableBlock();
+      }
     } else {
       _operatorStats.recordOutput(1, rows.size());
       return new TransferableBlock(rows, _resultSchema, DataBlock.Type.ROW);
     }
   }
 
+  /**
+   * @return an empty agg result block for non-group-by aggregation.
+   */
+  private TransferableBlock constructEmptyAggResultBlock() {
+    Object[] row = new Object[_aggCalls.size()];
+    for (int i = 0; i < _accumulators.length; i++) {
+      row[i] = _accumulators[i]._merger.initialize(null, _accumulators[i]._dataType);
+    }
+    return new TransferableBlock(Collections.singletonList(row), _resultSchema, DataBlock.Type.ROW);
+  }
+
   /**
    * @return whether or not the operator is ready to move on (EOS or ERROR)
    */
@@ -258,7 +274,7 @@ public class AggregateOperator extends MultiStageOperator {
     }
 
     @Override
-    public Object initialize(Object other) {
+    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
       PinotFourthMoment moment = new PinotFourthMoment();
       moment.increment(((Number) other).doubleValue());
       return moment;
@@ -285,7 +301,7 @@ public class AggregateOperator extends MultiStageOperator {
     }
 
     @Override
-    public Object initialize(Object other) {
+    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
       ObjectOpenHashSet<Object> set = new ObjectOpenHashSet<>();
       set.add(other);
       return set;
@@ -307,12 +323,12 @@ public class AggregateOperator extends MultiStageOperator {
     /**
      * Initializes the merger based on the first input
      */
-    default Object initialize(Object other) {
-      return other;
+    default Object initialize(Object other, DataSchema.ColumnDataType dataType) {
+      return other == null ? dataType.getNullPlaceholder() : other;
     }
 
     /**
-     * Merges the existing aggregate (the result of {@link #initialize(Object)}) with
+     * Merges the existing aggregate (the result of {@link #initialize(Object, DataSchema.ColumnDataType)}) with
      * the new value coming in (which may be an aggregate in and of itself).
      */
     Object merge(Object agg, Object value);
@@ -353,22 +369,22 @@ public class AggregateOperator extends MultiStageOperator {
     final Object _literal;
     final Map<Key, Object> _results = new HashMap<>();
     final Merger _merger;
+    final DataSchema.ColumnDataType _dataType;
 
     Accumulator(RexExpression.FunctionCall aggCall, Map<String, Function<DataSchema.ColumnDataType, Merger>> merger,
         String functionName, DataSchema inputSchema) {
       // agg function operand should either be a InputRef or a Literal
-      DataSchema.ColumnDataType dataType;
       RexExpression rexExpression = toAggregationFunctionOperand(aggCall);
       if (rexExpression instanceof RexExpression.InputRef) {
         _inputRef = ((RexExpression.InputRef) rexExpression).getIndex();
         _literal = null;
-        dataType = inputSchema.getColumnDataType(_inputRef);
+        _dataType = inputSchema.getColumnDataType(_inputRef);
       } else {
         _inputRef = -1;
         _literal = ((RexExpression.Literal) rexExpression).getValue();
-        dataType = DataSchema.ColumnDataType.fromDataType(rexExpression.getDataType(), false);
+        _dataType = DataSchema.ColumnDataType.fromDataType(rexExpression.getDataType(), true);
       }
-      _merger = merger.get(functionName).apply(dataType);
+      _merger = merger.get(functionName).apply(_dataType);
     }
 
     void accumulate(Key key, Object[] row) {
@@ -377,7 +393,7 @@ public class AggregateOperator extends MultiStageOperator {
       Object value = _inputRef == -1 ? _literal : row[_inputRef];
 
       if (currentRes == null) {
-        _results.put(key, _merger.initialize(value));
+        _results.put(key, _merger.initialize(value, _dataType));
       } else {
         Object mergedResult = _merger.merge(currentRes, value);
         _results.put(key, mergedResult);
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 6f2212e6b6..0bda46470e 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -190,7 +190,7 @@ public class AggregateOperatorTest {
 
     AggregateOperator.Merger merger = Mockito.mock(AggregateOperator.Merger.class);
     Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
-    Mockito.when(merger.initialize(Mockito.any())).thenReturn(1d);
+    Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
     DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(_input, outSchema, calls, group, inSchema, ImmutableMap.of("SUM", cdt -> merger), 1, 2);
@@ -201,7 +201,7 @@ public class AggregateOperatorTest {
     // Then:
     // should call merger twice, one from second row in first block and two from the first row
     // in second block
-    Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any());
+    Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
     Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
     Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 12d},
         "Expected two columns (group by key, agg value)");
diff --git a/pinot-query-runtime/src/test/resources/queries/Aggregates.json b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
index e4243d2fc0..6ce21afbec 100644
--- a/pinot-query-runtime/src/test/resources/queries/Aggregates.json
+++ b/pinot-query-runtime/src/test/resources/queries/Aggregates.json
@@ -138,6 +138,34 @@
         "comment": "issue with converting data types:  Unexpected RelDataTypeField: ANY for column: EXPR$0",
         "description": "sum with inner function",
         "sql": "SELECT sum(pow(int_col, 2)) FROM {tbl}"
+      },
+      {
+        "ignored": true,
+        "comment": "sum empty returns [0] instead of [null] at the moment",
+        "description": "sum empty input after filter",
+        "sql": "SELECT sum(int_col) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+      },
+      {
+        "description": "count empty input after filter",
+        "sql": "SELECT count(*) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+      },
+      {
+        "description": "count empty input after filter",
+        "sql": "SELECT count(int_col) FROM {tbl} WHERE string_col IN ('foo', 'bar')"
+      },
+      {
+        "ignored": true,
+        "comment": "sum empty returns [0] instead of [null] at the moment",
+        "description": "sum empty input after filter with subquery",
+        "sql": "SELECT sum(int_col) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
+      },
+      {
+        "description": "count empty input after filter with sub-query",
+        "sql": "SELECT count(*) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
+      },
+      {
+        "description": "count empty input after filter with sub-query",
+        "sql": "SELECT count(int_col) FROM {tbl} WHERE string_col IN ( SELECT string_col FROM {tbl} WHERE int_col BETWEEN 1 AND 0 GROUP BY string_col )"
       }
     ]
   },
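
Two aspects of the patch are worth calling out for readers skimming the diff. First, Merger.initialize(...) gains a DataSchema.ColumnDataType parameter so that, when it is seeded with a null value (the empty-input case), the default implementation can fall back to the column type's null placeholder instead of returning null; the Accumulator now stores its resolved column type so both accumulate() and the empty-result path can use it. Second, the new cases in Aggregates.json pin down the expected SQL semantics: COUNT over an empty input yields 0, while the two ignored cases record that an empty SUM currently surfaces as [0] rather than the standard SQL NULL. The sketch below mirrors the widened initialize contract with stand-in types; ColumnType and the placeholder values are hypothetical simplifications, not Pinot's DataSchema.ColumnDataType.

    // Stand-in sketch of the two-argument initialize contract (hypothetical
    // types; Pinot's real implementation lives in AggregateOperator.Merger).
    final class MergerSketch {
      enum ColumnType { LONG, DOUBLE }

      /** Default initialize: keep the first value, or fall back to a per-type placeholder. */
      static Object initialize(Object first, ColumnType type) {
        if (first != null) {
          return first;
        }
        // Mirrors dataType.getNullPlaceholder() in the patch; the concrete
        // placeholder values here are illustrative assumptions.
        switch (type) {
          case LONG:
            return 0L;
          default:
            return 0d;
        }
      }

      /** A SUM-style merge step, as in the accumulate() loop of the patch. */
      static Object merge(Object agg, Object value) {
        return ((Number) agg).doubleValue() + ((Number) value).doubleValue();
      }

      public static void main(String[] args) {
        // Empty input: seeding with null now yields a typed default value
        // instead of a null that downstream operators would have to special-case.
        Object seed = initialize(null, ColumnType.DOUBLE);
        System.out.println(seed);           // 0.0
        System.out.println(merge(seed, 5)); // 5.0 once real rows arrive
      }
    }

Passing the column type into initialize keeps the per-function mergers themselves unchanged while letting the shared default handle the null seed, which is what allows constructEmptyAggResultBlock() to build the single default row for the non-group-by case.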


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org