Posted to commits@pinot.apache.org by ro...@apache.org on 2023/07/28 20:36:05 UTC

[pinot] branch master updated: [multistage][agg] support agg with filter (#11144)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 05989559e0 [multistage][agg] support agg with filter (#11144)
05989559e0 is described below

commit 05989559e06d448d23d55147672345ede624381d
Author: Rong Rong <ro...@apache.org>
AuthorDate: Fri Jul 28 13:35:59 2023 -0700

    [multistage][agg] support agg with filter (#11144)
    
    * [init][agg] support agg filter where clause
    * limitations still apply for nullable vs. non-nullable results and for merging the agg filter with the select filter; these will be addressed in follow-ups.
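    * example query shape (cf. the FilterAggregates.json test cases below):
        SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl}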
    
    ---------
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../pinot/common/datablock/DataBlockUtils.java     | 311 +++++++++++++++++++++
 .../PinotAggregateExchangeNodeInsertRule.java      |   2 +-
 .../query/parser/CalciteRexExpressionParser.java   |  30 +-
 .../query/planner/plannode/AggregateNode.java      |   7 +
 .../query/runtime/operator/AggregateOperator.java  |  75 +++--
 .../operator/MultistageAggregationExecutor.java    |  27 +-
 .../operator/MultistageGroupByExecutor.java        |  83 +++++-
 .../runtime/operator/block/DataBlockValSet.java    |  15 +-
 .../operator/block/FilteredDataBlockValSet.java    |  33 +--
 .../query/runtime/plan/PhysicalPlanVisitor.java    |   2 +-
 .../plan/server/ServerPlanRequestVisitor.java      |   6 +-
 .../runtime/operator/AggregateOperatorTest.java    |  16 +-
 .../test/resources/queries/FilterAggregates.json   | 166 +++++++++++
 13 files changed, 679 insertions(+), 94 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
index 99c3e4df46..b5d8321280 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/datablock/DataBlockUtils.java
@@ -23,6 +23,7 @@ import java.math.BigDecimal;
 import java.nio.ByteBuffer;
 import java.sql.Timestamp;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -560,4 +561,314 @@ public final class DataBlockUtils {
 
     return rows;
   }
+
+  /**
+   * Given a datablock and the column index, extracts the integer values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return int array of values in the column
+   */
+  public static int[] extractIntValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    int[] rows = new int[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = (int) dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = (int) dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (int) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).intValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the long values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return long array of values in the column
+   */
+  public static long[] extractLongValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    long[] rows = new long[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = (long) dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (long) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).longValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the float values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return float array of values in the column
+   */
+  public static float[] extractFloatValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    float[] rows = new float[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = (float) dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).floatValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the double values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return double array of values in the column
+   */
+  public static double[] extractDoubleValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    double[] rows = new double[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = dataBlock.getInt(inRowId, columnIndex);
+            break;
+          case LONG:
+            rows[outRowId++] = dataBlock.getLong(inRowId, columnIndex);
+            break;
+          case FLOAT:
+            rows[outRowId++] = dataBlock.getFloat(inRowId, columnIndex);
+            break;
+          case DOUBLE:
+            rows[outRowId++] = dataBlock.getDouble(inRowId, columnIndex);
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = dataBlock.getBigDecimal(inRowId, columnIndex).doubleValue();
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the BigDecimal values for the column. Prefer using this function
+   * over extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to
+   * the desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return BigDecimal array of values in the column
+   */
+  public static BigDecimal[] extractBigDecimalValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    BigDecimal[] rows = new BigDecimal[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        switch (columnDataTypes[columnIndex]) {
+          case INT:
+          case BOOLEAN:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getInt(inRowId, columnIndex));
+            break;
+          case LONG:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getLong(inRowId, columnIndex));
+            break;
+          case FLOAT:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getFloat(inRowId, columnIndex));
+            break;
+          case DOUBLE:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getDouble(inRowId, columnIndex));
+            break;
+          case BIG_DECIMAL:
+            rows[outRowId++] = BigDecimal.valueOf(dataBlock.getBigDecimal(inRowId, columnIndex).doubleValue());
+            break;
+          default:
+            throw new IllegalStateException(
+                String.format("Unsupported data type: %s for column: %s", columnDataTypes[columnIndex], columnIndex));
+        }
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the String values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return String array of values in the column
+   */
+  public static String[] extractStringValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    String[] rows = new String[numRows];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        rows[outRowId++] = dataBlock.getString(inRowId, columnIndex);
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
+
+  /**
+   * Given a datablock and the column index, extracts the byte values for the column. Prefer using this function over
+   * extractRowFromDatablock if the desired datatype is known to prevent autoboxing to Object and later unboxing to the
+   * desired type.
+   * Only rows whose filter column at {@code filterArgIdx} reads 1 (i.e. the filter passes) are extracted.
+   * This only works on ROW format. TODO: Add support for COLUMNAR format.
+   * @return byte[][] array of values in the column
+   */
+  public static byte[][] extractBytesValuesForColumn(DataBlock dataBlock, int columnIndex, int filterArgIdx) {
+    DataSchema dataSchema = dataBlock.getDataSchema();
+    DataSchema.ColumnDataType[] columnDataTypes = dataSchema.getColumnDataTypes();
+
+    // Get null bitmap for the column.
+    RoaringBitmap nullBitmap = extractNullBitmaps(dataBlock)[columnIndex];
+    int numRows = dataBlock.getNumberOfRows();
+
+    byte[][] rows = new byte[numRows][];
+    int outRowId = 0;
+    for (int inRowId = 0; inRowId < numRows; inRowId++) {
+      if (dataBlock.getInt(inRowId, filterArgIdx) == 1) {
+        if (nullBitmap != null && nullBitmap.contains(inRowId)) {
+          outRowId++;
+          continue;
+        }
+        rows[outRowId++] = dataBlock.getBytes(inRowId, columnIndex).getBytes();
+      }
+    }
+    return Arrays.copyOfRange(rows, 0, outRowId);
+  }
 }
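
All of the extract*ValuesForColumn overloads added above share the same filtered-compaction loop.
A minimal standalone sketch of that loop, using plain arrays in place of the DataBlock and null
bitmap (FilteredExtractSketch, values, filterColumn and nullRows are illustrative names, not Pinot
APIs):

    import java.util.Arrays;
    import java.util.BitSet;

    public final class FilteredExtractSketch {
      // Keep values[i] only where filterColumn[i] == 1. A null row that passes the
      // filter still advances the output cursor, leaving a default-valued slot so
      // positions stay aligned with the null bitmap, mirroring the methods above.
      static int[] extractFiltered(int[] values, int[] filterColumn, BitSet nullRows) {
        int[] out = new int[values.length];
        int outRowId = 0;
        for (int inRowId = 0; inRowId < values.length; inRowId++) {
          if (filterColumn[inRowId] == 1) {
            if (nullRows.get(inRowId)) {
              outRowId++;
              continue;
            }
            out[outRowId++] = values[inRowId];
          }
        }
        return Arrays.copyOfRange(out, 0, outRowId); // trim to the number of passing rows
      }

      public static void main(String[] args) {
        int[] values = {10, 20, 30, 40};
        int[] filterColumn = {1, 0, 1, 1};
        BitSet nullRows = new BitSet();
        nullRows.set(2); // row 2 is null but passes the filter
        // prints [10, 0, 40]: row 1 is filtered out, row 2 keeps a default-valued slot
        System.out.println(Arrays.toString(extractFiltered(values, filterColumn, nullRows)));
      }
    }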
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index 60efa69fde..df904123d2 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -324,7 +324,7 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
         orgAggCall.isApproximate(),
         orgAggCall.ignoreNulls(),
         argList,
-        orgAggCall.filterArg,
+        aggType.isInputIntermediateFormat() ? -1 : orgAggCall.filterArg,
         orgAggCall.distinctKeys,
         orgAggCall.collation,
         numberGroups,
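
Note on this one-line change: only the leaf-stage aggregation evaluates the FILTER clause. Once an
aggregate consumes input that is already in intermediate format, its rows were presumably already
filtered below, so the filterArg is cleared to -1 here. Schematically:

    leaf agg (filterArg >= 0, applies FILTER)  ->  exchange  ->  intermediate/final agg (filterArg = -1)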
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
index 632471f6b0..4c31fc86f4 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.query.parser;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
@@ -67,19 +68,30 @@ public class CalciteRexExpressionParser {
   // Relational conversion Utils
   // --------------------------------------------------------------------------
 
-  public static List<Expression> overwriteSelectList(List<RexExpression> rexNodeList, PinotQuery pinotQuery) {
-    return addSelectList(new ArrayList<>(), rexNodeList, pinotQuery);
-  }
-
-  public static List<Expression> addSelectList(List<Expression> existingList, List<RexExpression> rexNodeList,
-      PinotQuery pinotQuery) {
-    List<Expression> selectExpr = new ArrayList<>(existingList);
-
-    final Iterator<RexExpression> iterator = rexNodeList.iterator();
+  public static List<Expression> convertProjectList(List<RexExpression> projectList, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>();
+    final Iterator<RexExpression> iterator = projectList.iterator();
     while (iterator.hasNext()) {
       final RexExpression next = iterator.next();
       selectExpr.add(toExpression(next, pinotQuery));
     }
+    return selectExpr;
+  }
+
+  public static List<Expression> convertAggregateList(List<Expression> groupSetList, List<RexExpression> aggCallList,
+      List<Integer> filterArgIndices, PinotQuery pinotQuery) {
+    List<Expression> selectExpr = new ArrayList<>(groupSetList);
+
+    for (int idx = 0; idx < aggCallList.size(); idx++) {
+      final RexExpression aggCall = aggCallList.get(idx);
+      int filterArgIdx = filterArgIndices.get(idx);
+      if (filterArgIdx == -1) {
+        selectExpr.add(toExpression(aggCall, pinotQuery));
+      } else {
+        selectExpr.add(toExpression(new RexExpression.FunctionCall(SqlKind.FILTER, aggCall.getDataType(), "FILTER",
+            Arrays.asList(aggCall, new RexExpression.InputRef(filterArgIdx))), pinotQuery));
+      }
+    }
 
     return selectExpr;
   }
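
The net effect of convertAggregateList: every aggregate call that carries a filter argument is
wrapped in a FILTER function call whose second operand is an input ref pointing at the projected
boolean filter column. Illustratively (assuming the predicate happens to be projected as input
column $3, a made-up index):

    SUM(double_col) FILTER (WHERE bool_col)   -->   FILTER(SUM(double_col), $3)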
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
index c465fe93f6..5c8a0999c0 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/plannode/AggregateNode.java
@@ -36,6 +36,8 @@ public class AggregateNode extends AbstractPlanNode {
   @ProtoProperties
   private List<RexExpression> _aggCalls;
   @ProtoProperties
+  private List<Integer> _filterArgIndices;
+  @ProtoProperties
   private List<RexExpression> _groupSet;
   @ProtoProperties
   private AggType _aggType;
@@ -49,6 +51,7 @@ public class AggregateNode extends AbstractPlanNode {
     super(planFragmentId, dataSchema);
     Preconditions.checkState(areHintsValid(relHints), "invalid sql hint for agg node: {}", relHints);
     _aggCalls = aggCalls.stream().map(RexExpression::toRexExpression).collect(Collectors.toList());
+    _filterArgIndices = aggCalls.stream().map(c -> c.filterArg).collect(Collectors.toList());
     _groupSet = groupSet;
     _nodeHint = new NodeHint(relHints);
     _aggType = AggType.valueOf(PinotHintStrategyTable.getHintOption(relHints, PinotHintOptions.INTERNAL_AGG_OPTIONS,
@@ -63,6 +66,10 @@ public class AggregateNode extends AbstractPlanNode {
     return _aggCalls;
   }
 
+  public List<Integer> getFilterArgIndices() {
+    return _filterArgIndices;
+  }
+
   public List<RexExpression> getGroupSet() {
     return _groupSet;
   }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 494ad1d737..22814b4571 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -35,7 +35,6 @@ import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
 import org.apache.pinot.query.planner.logical.LiteralHintUtils;
@@ -44,7 +43,10 @@ import org.apache.pinot.query.planner.plannode.AbstractPlanNode;
 import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
+import org.apache.pinot.query.runtime.operator.block.DataBlockValSet;
+import org.apache.pinot.query.runtime.operator.block.FilteredDataBlockValSet;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.apache.pinot.spi.data.FieldSpec;
 
 
 /**
@@ -90,31 +92,28 @@ public class AggregateOperator extends MultiStageOperator {
   private MultistageAggregationExecutor _aggregationExecutor;
   private MultistageGroupByExecutor _groupByExecutor;
 
-  // TODO: refactor Pinot Reducer code to support the intermediate stage agg operator.
-  // aggCalls has to be a list of FunctionCall and cannot be null
-  // groupSet has to be a list of InputRef and cannot be null
-  // TODO: Add these two checks when we confirm we can handle error in upstream ctor call.
-  public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator,
-      DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet,
-      AggType aggType) {
-    this(context, inputOperator, resultSchema, inputSchema, aggCalls, groupSet, aggType,
-        new AbstractPlanNode.NodeHint());
-  }
-
   @VisibleForTesting
-  public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator,
-      DataSchema resultSchema, DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet,
-      AggType aggType, AbstractPlanNode.NodeHint nodeHint) {
+  public AggregateOperator(OpChainExecutionContext context, MultiStageOperator inputOperator, DataSchema resultSchema,
+      DataSchema inputSchema, List<RexExpression> aggCalls, List<RexExpression> groupSet, AggType aggType,
+      @Nullable List<Integer> filterArgIndices, @Nullable AbstractPlanNode.NodeHint nodeHint) {
     super(context);
     _inputOperator = inputOperator;
     _resultSchema = resultSchema;
     _inputSchema = inputSchema;
     _aggType = aggType;
+    // filter arg index array
+    int[] filterArgIndexArray;
+    if (filterArgIndices == null || filterArgIndices.size() == 0) {
+      filterArgIndexArray = null;
+    } else {
+      filterArgIndexArray = filterArgIndices.stream().mapToInt(Integer::intValue).toArray();
+    }
+    // filter operand and literal hints
     if (nodeHint != null && nodeHint._hintOptions != null
         && nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS) != null) {
-      _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap(nodeHint._hintOptions
-          .get(PinotHintOptions.INTERNAL_AGG_OPTIONS)
-          .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
+      _aggCallSignatureMap = LiteralHintUtils.hintStringToLiteralMap(
+          nodeHint._hintOptions.get(PinotHintOptions.INTERNAL_AGG_OPTIONS)
+              .get(PinotHintOptions.InternalAggregateOptions.AGG_CALL_SIGNATURE));
     } else {
       _aggCallSignatureMap = Collections.emptyMap();
     }
@@ -126,6 +125,7 @@ public class AggregateOperator extends MultiStageOperator {
 
     // Convert groupSet to ExpressionContext that our aggregation functions understand.
     List<ExpressionContext> groupByExpr = getGroupSet(groupSet);
+
     List<FunctionContext> functionContexts = getFunctionContexts(aggCalls);
     AggregationFunction[] aggFunctions = new AggregationFunction[functionContexts.size()];
 
@@ -136,12 +136,14 @@ public class AggregateOperator extends MultiStageOperator {
     // Initialize the appropriate executor.
     if (!groupSet.isEmpty()) {
       _isGroupByAggregation = true;
-      _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap,
-          _resultSchema);
+      _groupByExecutor =
+          new MultistageGroupByExecutor(groupByExpr, aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap,
+              _resultSchema);
     } else {
       _isGroupByAggregation = false;
-      _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap,
-          _resultSchema);
+      _aggregationExecutor =
+          new MultistageAggregationExecutor(aggFunctions, filterArgIndexArray, aggType, _colNameToIndexMap,
+              _resultSchema);
     }
   }
 
@@ -253,7 +255,7 @@ public class AggregateOperator extends MultiStageOperator {
     // The literal value here does not matter. We create a dummy literal here just so that the count aggregation
     // has some column to process.
     if (aggArguments.isEmpty()) {
-      aggArguments.add(ExpressionContext.forIdentifier("__PLACEHOLDER__"));
+      aggArguments.add(ExpressionContext.forLiteralContext(FieldSpec.DataType.STRING, "__PLACEHOLDER__"));
     }
 
     return new FunctionContext(FunctionContext.Type.AGGREGATION, functionName, aggArguments);
@@ -320,8 +322,8 @@ public class AggregateOperator extends MultiStageOperator {
   // TODO: If the previous block is not mailbox received, this method is not efficient.  Then getDataBlock() will
   //  convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
   //  value primitive type.
-  static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
-      TransferableBlock block, DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap) {
+  static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction, TransferableBlock block,
+      DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap, int filterArgIdx) {
     List<ExpressionContext> expressions = aggFunction.getInputExpressions();
     int numExpressions = expressions.size();
     if (numExpressions == 0) {
@@ -330,17 +332,34 @@ public class AggregateOperator extends MultiStageOperator {
 
     Map<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<>();
     for (ExpressionContext expression : expressions) {
-      if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER)
-          && !"__PLACEHOLDER__".equals(expression.getIdentifier())) {
+      if (expression.getType().equals(ExpressionContext.Type.IDENTIFIER) && !"__PLACEHOLDER__".equals(
+          expression.getIdentifier())) {
         int index = colNameToIndexMap.get(expression.getIdentifier());
         DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
         Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
-        blockValSetMap.put(expression, new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
+        if (filterArgIdx == -1) {
+          blockValSetMap.put(expression, new DataBlockValSet(dataType, block.getDataBlock(), index));
+        } else {
+          blockValSetMap.put(expression,
+              new FilteredDataBlockValSet(dataType, block.getDataBlock(), index, filterArgIdx));
+        }
       }
     }
     return blockValSetMap;
   }
 
+  static int computeBlockNumRows(TransferableBlock block, int filterArgIdx) {
+    if (filterArgIdx == -1) {
+      return block.getNumRows();
+    } else {
+      int rowCount = 0;
+      for (int rowId = 0; rowId < block.getNumRows(); rowId++) {
+        rowCount += block.getDataBlock().getInt(rowId, filterArgIdx) == 1 ? 1 : 0;
+      }
+      return rowCount;
+    }
+  }
+
   static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row,
       Map<String, Integer> colNameToIndexMap) {
     List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
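
One contract worth spelling out: the row count handed to AggregationFunction.aggregate(...) must
match the lengths of the arrays produced by the corresponding FilteredDataBlockValSet, and both
count every row whose filter column reads 1, null rows included. For a hypothetical block:

    // filter column values per row: [1, 0, 1, 1, 0]
    // computeBlockNumRows(block, filterArgIdx) == 3, and the FilteredDataBlockValSet
    // getters for the same filterArgIdx return 3-element arrays (nulls occupy
    // default-valued slots, tracked separately by the null bitmap).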
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
index 19e4f66cc6..b352997ac1 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.BlockValSet;
@@ -42,13 +43,15 @@ public class MultistageAggregationExecutor {
   private final DataSchema _resultSchema;
 
   private final AggregationFunction[] _aggFunctions;
+  private final int[] _filterArgIndices;
 
   // Result holders for each mode.
   private final AggregationResultHolder[] _aggregateResultHolder;
   private final Object[] _mergeResultHolder;
 
-  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions,
+  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, @Nullable int[] filterArgIndices,
       AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
+    _filterArgIndices = filterArgIndices;
     _aggFunctions = aggFunctions;
     _aggType = aggType;
     _colNameToIndexMap = colNameToIndexMap;
@@ -116,11 +119,23 @@ public class MultistageAggregationExecutor {
   }
 
   private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
-    for (int i = 0; i < _aggFunctions.length; i++) {
-      AggregationFunction aggregationFunction = _aggFunctions[i];
-      Map<ExpressionContext, BlockValSet> blockValSetMap =
-          AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
-      aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+    if (_filterArgIndices == null) {
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1);
+        aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
+      }
+    } else {
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        int filterArgIdx = _filterArgIndices[i];
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap,
+                filterArgIdx);
+        int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx);
+        aggregationFunction.aggregate(numRows, _aggregateResultHolder[i], blockValSetMap);
+      }
     }
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index 5eacba025b..e33cc491cf 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -19,9 +19,11 @@
 package org.apache.pinot.query.runtime.operator;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.BlockValSet;
@@ -47,6 +49,7 @@ public class MultistageGroupByExecutor {
 
   private final List<ExpressionContext> _groupSet;
   private final AggregationFunction[] _aggFunctions;
+  private final int[] _filterArgIndices;
 
   // Group By Result holders for each mode
   private final GroupByResultHolder[] _aggregateResultHolders;
@@ -57,11 +60,13 @@ public class MultistageGroupByExecutor {
   private final Map<Key, Integer> _groupKeyToIdMap;
 
   public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
-      AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
+      @Nullable int[] filterArgIndices, AggType aggType, Map<String, Integer> colNameToIndexMap,
+      DataSchema resultSchema) {
     _aggType = aggType;
     _colNameToIndexMap = colNameToIndexMap;
     _groupSet = groupByExpr;
     _aggFunctions = aggFunctions;
+    _filterArgIndices = filterArgIndices;
     _resultSchema = resultSchema;
 
     _aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
@@ -70,9 +75,9 @@ public class MultistageGroupByExecutor {
     _groupKeyToIdMap = new HashMap<>();
 
     for (int i = 0; i < _aggFunctions.length; i++) {
-      _aggregateResultHolders[i] =
-          _aggFunctions[i].createGroupByResultHolder(InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
-              InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
+      _aggregateResultHolders[i] = _aggFunctions[i].createGroupByResultHolder(
+          InstancePlanMakerImplV2.DEFAULT_MAX_INITIAL_RESULT_HOLDER_CAPACITY,
+          InstancePlanMakerImplV2.DEFAULT_NUM_GROUPS_LIMIT);
     }
   }
 
@@ -129,15 +134,29 @@ public class MultistageGroupByExecutor {
   }
 
   private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
-    int[] intKeys = generateGroupByKeys(block.getContainer());
-
-    for (int i = 0; i < _aggFunctions.length; i++) {
-      AggregationFunction aggregationFunction = _aggFunctions[i];
-      Map<ExpressionContext, BlockValSet> blockValSetMap =
-          AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
-      GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
-      groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
-      aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+    if (_filterArgIndices == null) {
+      int[] intKeys = generateGroupByKeys(block.getContainer());
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap, -1);
+        GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+        groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+        aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
+      }
+    } else {
+      for (int i = 0; i < _aggFunctions.length; i++) {
+        AggregationFunction aggregationFunction = _aggFunctions[i];
+        int filterArgIdx = _filterArgIndices[i];
+        int[] intKeys = generateGroupByKeys(block.getContainer(), filterArgIdx);
+        Map<ExpressionContext, BlockValSet> blockValSetMap =
+            AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap,
+                filterArgIdx);
+        int numRows = AggregateOperator.computeBlockNumRows(block, filterArgIdx);
+        GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
+        groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
+        aggregationFunction.aggregateGroupBySV(numRows, intKeys, groupByResultHolder, blockValSetMap);
+      }
     }
   }
 
@@ -191,4 +210,42 @@ public class MultistageGroupByExecutor {
     }
     return rowIntKeys;
   }
+
+  /**
+   * Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by
+   * GroupByAggregationResultHolders used in v1 aggregations.
+   * <p>
+   * Returns the int key for each row.
+   */
+  private int[] generateGroupByKeys(List<Object[]> rows, int filterArgIndex) {
+    int numRows = rows.size();
+    int[] rowIntKeys = new int[numRows];
+    int numKeys = _groupSet.size();
+    if (filterArgIndex == -1) {
+      for (int rowId = 0; rowId < numRows; rowId++) {
+        Object[] row = rows.get(rowId);
+        Object[] keyValues = new Object[numKeys];
+        for (int j = 0; j < numKeys; j++) {
+          keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())];
+        }
+        Key rowKey = new Key(keyValues);
+        rowIntKeys[rowId] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size());
+      }
+      return rowIntKeys;
+    } else {
+      int outRowId = 0;
+      for (int inRowId = 0; inRowId < numRows; inRowId++) {
+        Object[] row = rows.get(inRowId);
+        if ((Boolean) row[filterArgIndex]) {
+          Object[] keyValues = new Object[numKeys];
+          for (int j = 0; j < numKeys; j++) {
+            keyValues[j] = row[_colNameToIndexMap.get(_groupSet.get(j).getIdentifier())];
+          }
+          Key rowKey = new Key(keyValues);
+          rowIntKeys[outRowId++] = _groupKeyToIdMap.computeIfAbsent(rowKey, k -> _groupKeyToIdMap.size());
+        }
+      }
+      return Arrays.copyOfRange(rowIntKeys, 0, outRowId);
+    }
+  }
 }
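
Both generateGroupByKeys variants above intern each distinct group key into a dense 0-based id via
computeIfAbsent, so that v1 GroupByResultHolders can index results by int. A minimal standalone
sketch of that interning step (GroupKeyInterningSketch and the String keys are illustrative; the
real code builds a Key from the group-by column values):

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    public final class GroupKeyInterningSketch {
      public static void main(String[] args) {
        Map<String, Integer> groupKeyToIdMap = new HashMap<>();
        String[] rowKeys = {"a", "b", "a", "c", "b"};
        int[] intKeys = new int[rowKeys.length];
        for (int rowId = 0; rowId < rowKeys.length; rowId++) {
          // first sighting of a key assigns the next dense id (the map's current size)
          intKeys[rowId] = groupKeyToIdMap.computeIfAbsent(rowKeys[rowId], k -> groupKeyToIdMap.size());
        }
        System.out.println(Arrays.toString(intKeys)); // [0, 1, 0, 2, 1]
      }
    }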
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
similarity index 90%
copy from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
copy to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
index 7ddaaf04c9..e1bbc077ef 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/DataBlockValSet.java
@@ -16,13 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.core.common;
+package org.apache.pinot.query.runtime.operator.block;
 
 import java.math.BigDecimal;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.datablock.DataBlockUtils;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
 import org.apache.pinot.segment.spi.index.reader.Dictionary;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.roaringbitmap.RoaringBitmap;
@@ -34,13 +35,13 @@ import org.roaringbitmap.RoaringBitmap;
  * aggregations using v1 aggregation functions.
  * TODO: Support MV
  */
-public class IntermediateStageBlockValSet implements BlockValSet {
-  private final FieldSpec.DataType _dataType;
-  private final DataBlock _dataBlock;
-  private final int _index;
-  private final RoaringBitmap _nullBitMap;
+public class DataBlockValSet implements BlockValSet {
+  protected final FieldSpec.DataType _dataType;
+  protected final DataBlock _dataBlock;
+  protected final int _index;
+  protected final RoaringBitmap _nullBitMap;
 
-  public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
+  public DataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
     _dataType = columnDataType.toDataType();
     _dataBlock = dataBlock;
     _index = colIndex;
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
similarity index 86%
rename from pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
rename to pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
index 7ddaaf04c9..e4231fbd71 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/common/IntermediateStageBlockValSet.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/block/FilteredDataBlockValSet.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.pinot.core.common;
+package org.apache.pinot.query.runtime.operator.block;
 
 import java.math.BigDecimal;
 import javax.annotation.Nullable;
@@ -27,6 +27,7 @@ import org.apache.pinot.segment.spi.index.reader.Dictionary;
 import org.apache.pinot.spi.data.FieldSpec;
 import org.roaringbitmap.RoaringBitmap;
 
+
 /**
  * In the multistage engine, the leaf stage servers process the data in columnar fashion. By the time the
  * intermediate stage receives the projected columns, they have been converted to a row-based format. This class
  * provides the capability to convert the row-based representation into blocks so that they can be used to process
  * aggregations using v1 aggregation functions.
  * TODO: Support MV
  */
-public class IntermediateStageBlockValSet implements BlockValSet {
-  private final FieldSpec.DataType _dataType;
-  private final DataBlock _dataBlock;
-  private final int _index;
-  private final RoaringBitmap _nullBitMap;
+public class FilteredDataBlockValSet extends DataBlockValSet {
+  private final int _filterIdx;
 
-  public IntermediateStageBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex) {
-    _dataType = columnDataType.toDataType();
-    _dataBlock = dataBlock;
-    _index = colIndex;
-    _nullBitMap = dataBlock.getNullRowIds(colIndex);
+  public FilteredDataBlockValSet(DataSchema.ColumnDataType columnDataType, DataBlock dataBlock, int colIndex,
+      int filterIdx) {
+    super(columnDataType, dataBlock, colIndex);
+    _filterIdx = filterIdx;
   }
 
   /**
@@ -80,37 +77,37 @@ public class IntermediateStageBlockValSet implements BlockValSet {
 
   @Override
   public int[] getIntValuesSV() {
-    return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractIntValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public long[] getLongValuesSV() {
-    return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractLongValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public float[] getFloatValuesSV() {
-    return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractFloatValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public double[] getDoubleValuesSV() {
-    return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractDoubleValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public BigDecimal[] getBigDecimalValuesSV() {
-    return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractBigDecimalValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public String[] getStringValuesSV() {
-    return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractStringValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
   public byte[][] getBytesValuesSV() {
-    return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index);
+    return DataBlockUtils.extractBytesValuesForColumn(_dataBlock, _index, _filterIdx);
   }
 
   @Override
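
Taken together with the AggregateOperator wiring above, the per-aggregate-call choice boils down to
the following sketch (dataType, dataBlock, colIdx and filterIdx are assumed to be in scope; classes
as in this patch):

    // No FILTER clause (filterArgIdx == -1): expose every row to the v1 aggregation function.
    BlockValSet unfiltered = new DataBlockValSet(dataType, dataBlock, colIdx);

    // FILTER clause present: expose only rows whose projected boolean filter column
    // (read as int 0/1) equals 1; array lengths then line up with computeBlockNumRows().
    BlockValSet filtered = new FilteredDataBlockValSet(dataType, dataBlock, colIdx, filterIdx);
    int[] passingInts = filtered.getIntValuesSV();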
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
index 79f2275274..2340545437 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/PhysicalPlanVisitor.java
@@ -101,7 +101,7 @@ public class PhysicalPlanVisitor implements PlanNodeVisitor<MultiStageOperator,
     DataSchema resultSchema = node.getDataSchema();
 
     return new AggregateOperator(context.getOpChainExecutionContext(), nextOperator, resultSchema, inputSchema,
-        node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getNodeHint());
+        node.getAggCalls(), node.getGroupSet(), node.getAggType(), node.getFilterArgIndices(), node.getNodeHint());
   }
 
   @Override
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
index 8fa6df2756..5e3873fbcd 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/plan/server/ServerPlanRequestVisitor.java
@@ -71,8 +71,8 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla
         .setGroupByList(CalciteRexExpressionParser.convertGroupByList(node.getGroupSet(), context.getPinotQuery()));
     // set agg list
     context.getPinotQuery().setSelectList(
-        CalciteRexExpressionParser.addSelectList(context.getPinotQuery().getGroupByList(), node.getAggCalls(),
-            context.getPinotQuery()));
+        CalciteRexExpressionParser.convertAggregateList(context.getPinotQuery().getGroupByList(), node.getAggCalls(),
+            node.getFilterArgIndices(), context.getPinotQuery()));
     return null;
   }
 
@@ -149,7 +149,7 @@ public class ServerPlanRequestVisitor implements PlanNodeVisitor<Void, ServerPla
   public Void visitProject(ProjectNode node, ServerPlanRequestContext context) {
     visitChildren(node, context);
     context.getPinotQuery()
-        .setSelectList(CalciteRexExpressionParser.overwriteSelectList(node.getProjects(), context.getPinotQuery()));
+        .setSelectList(CalciteRexExpressionParser.convertProjectList(node.getProjects(), context.getPinotQuery()));
     return null;
   }
 
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 8048bad5f4..d32d38b2b0 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -79,7 +79,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.INTERMEDIATE, null, null);
 
     // When:
     TransferableBlock block1 = operator.nextBlock(); // build
@@ -101,7 +101,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.LEAF);
+            AggType.LEAF, null, null);
 
     // When:
     TransferableBlock block = operator.nextBlock();
@@ -125,7 +125,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.LEAF);
+            AggType.LEAF, null, null);
 
     // When:
     TransferableBlock block1 = operator.nextBlock(); // build when reading NoOp block
@@ -150,7 +150,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.INTERMEDIATE, null, null);
 
     // When:
     TransferableBlock block1 = operator.nextBlock();
@@ -177,7 +177,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, LONG});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.FINAL);
+            AggType.FINAL, null, null);
 
     // When:
     TransferableBlock block1 = operator.nextBlock();
@@ -201,7 +201,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{STRING, DOUBLE});
     AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
         outSchema, inSchema, Collections.singletonList(agg),
-        Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF);
+        Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF, null, null);
     TransferableBlock result = sum0GroupBy1.getNextBlock();
     while (result.isNoOpBlock()) {
       result = sum0GroupBy1.getNextBlock();
@@ -229,7 +229,7 @@ public class AggregateOperatorTest {
     // When:
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.INTERMEDIATE, null, null);
   }
 
   @Test
@@ -248,7 +248,7 @@ public class AggregateOperatorTest {
     DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.INTERMEDIATE, null, null);
 
     // When:
     TransferableBlock block = operator.nextBlock();
diff --git a/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json
new file mode 100644
index 0000000000..af308031ab
--- /dev/null
+++ b/pinot-query-runtime/src/test/resources/queries/FilterAggregates.json
@@ -0,0 +1,166 @@
+{
+  "general_aggregate_with_filter_where": {
+    "tables": {
+      "tbl": {
+        "schema": [
+          {"name": "int_col", "type": "INT"},
+          {"name": "double_col", "type": "DOUBLE"},
+          {"name": "string_col", "type": "STRING"},
+          {"name": "bool_col", "type": "BOOLEAN"}
+        ],
+        "inputs": [
+          [2, 300, "a", true],
+          [2, 400, "a", true],
+          [3, 100, "b", false],
+          [100, 1, "b", false],
+          [101, 1.01, "c", false],
+          [150, 1.5, "c", false],
+          [175, 1.75, "c", true]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "ignored": true,
+        "comments": "FILTER WHERE clause with IN hard-wired to translate into subquery, which in this case should not happen.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl}"
+      },
+      {
+        "ignored": true,
+        "comments": "IS NULL and IS NOT NULL is not yet supported in filter conversion.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl}"
+      },
+      {
+        "ignored": true,
+        "comments": "agg with filter and group-by causes conversion issue on v1 if the group-by field is not in the select list",
+        "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col"
+      },
+      {
+        "ignored": true,
+        "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} GROUP BY int_col, string_col"
+      },
+      {
+        "ignored": true,
+        "comments": "mixed/conflict filter that requires merging in v1 is not supported",
+        "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col = 'b' GROUP BY double_col, bool_col"
+      },
+      {
+        "ignored": true,
+        "comments": "FILTER WHERE clause might omit group key entirely if nothing is being selected out, this is non-standard SQL behavior but it is v1 behavior",
+        "sql": "SELECT int_col, count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} GROUP BY int_col"
+      },
+      { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl}" },
+      { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl}" },
+      { "sql": "SELECT min(int_col) FILTER (WHERE bool_col IS TRUE), max(int_col) FILTER (WHERE bool_col AND int_col < 10), avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), sum(int_col), count(int_col), count(distinct(int_col)), count(*) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl}" },
+      { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col > 10) FROM {tbl} WHERE string_col='b'" },
+      { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(*) FROM {tbl} WHERE string_col='b'" },
+      { "sql": "SELECT int_col, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col > 0), 0), count(*) FROM {tbl} GROUP BY int_col" },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT int_col, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), 0) FROM {tbl} GROUP BY int_col, string_col"
+      },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT double_col, COALESCE(min(int_col) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col) FILTER (WHERE bool_col AND int_col < 10), 0), COALESCE(avg(int_col) FILTER (WHERE MOD(int_col, 3) = 0), 0), sum(int_col), count(int_col), count(distinct(int_col)), count(string_col) FILTER (WHERE MOD(int_col, 3) = 0) FROM {tbl} GROUP BY double_col"
+      },
+      { "sql": "SELECT double_col, bool_col, count(int_col) FILTER (WHERE string_col = 'a' OR int_col > 10), count(int_col) FROM {tbl} WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col > 10), 0), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col > 10), count(string_col) FROM {tbl} WHERE string_col='b' GROUP BY bool_col"
+      }
+    ]
+  },
+  "general_aggregate_with_filter_where_after_join": {
+    "tables": {
+      "tbl1": {
+        "schema": [
+          {"name": "int_col", "type": "INT"},
+          {"name": "double_col", "type": "DOUBLE"},
+          {"name": "string_col", "type": "STRING"},
+          {"name": "bool_col", "type": "BOOLEAN"}
+        ],
+        "inputs": [
+          [2, 300, "a", true],
+          [2, 400, "a", true],
+          [3, 100, "b", false],
+          [100, 1, "b", false],
+          [101, 1.01, "c", false],
+          [150, 1.5, "c", false],
+          [175, 1.75, "c", true]
+        ]
+      },
+      "tbl2": {
+        "schema":[
+          {"name": "int_col2", "type": "INT"},
+          {"name": "string_col2", "type": "STRING"},
+          {"name":  "double_col2", "type":  "DOUBLE"}
+        ],
+        "inputs": [
+          [1, "apple", 1000.0],
+          [2, "a", 1.323],
+          [3, "b", 1212.12],
+          [3, "c", 341],
+          [4, "orange", 1212.121]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "ignored": true,
+        "comments": "FILTER WHERE clause with IN hard-wired to translate into subquery, which in this case should not happen.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col IN ('a', 'b')), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2"
+      },
+      {
+        "ignored": true,
+        "comments": "IS NULL and IS NOT NULL is not yet supported in filter conversion.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col IS NOT NULL), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2"
+      },
+      {
+        "ignored": true,
+        "comments": "agg with filter and group-by causes conversion issue on v1 if the group-by field is not in the select list",
+        "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2"
+      },
+      {
+        "ignored": true,
+        "comments": "agg with group by and filter will create NULL-able columns that are unsupported with current AGG FILTER WHERE semantics.",
+        "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col"
+      },
+      {
+        "ignored": true,
+        "comments": "mixed/conflict filter that requires merging in v1 is not supported",
+        "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col = 'b' GROUP BY double_col, bool_col"
+      },
+      {
+        "ignored": true,
+        "comments": "FILTER WHERE clause might omit group key entirely if nothing is being selected out, this is non-standard SQL behavior but it is v1 behavior",
+        "sql": "SELECT int_col2, count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2"
+      },
+      { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+      { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+      { "sql": "SELECT min(int_col2) FILTER (WHERE bool_col IS TRUE), max(int_col2) FILTER (WHERE bool_col AND int_col2 < 10), avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(*) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2" },
+      { "sql": "SELECT count(*) FILTER (WHERE string_col = 'a' OR int_col2 > 10) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" },
+      { "sql": "SELECT min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b'" },
+      { "sql": "SELECT int_col2, COALESCE(count(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 0), 0), count(*) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2" },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT int_col2, string_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), COALESCE(count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY int_col2, string_col"
+      },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT double_col, COALESCE(min(int_col2) FILTER (WHERE bool_col IS TRUE), 0), COALESCE(max(int_col2) FILTER (WHERE bool_col AND int_col2 < 10), 0), COALESCE(avg(int_col2) FILTER (WHERE MOD(int_col2, 3) = 0), 0), sum(int_col2), count(int_col2), count(distinct(int_col2)), count(string_col) FILTER (WHERE MOD(int_col2, 3) = 0) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 GROUP BY double_col"
+      },
+      { "sql": "SELECT double_col, bool_col, count(int_col2) FILTER (WHERE string_col = 'a' OR int_col2 > 10), count(int_col2) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col IN ('a', 'b') GROUP BY double_col, bool_col" },
+      {
+        "ignored": true,
+        "comments": "Calcite limitation on SQL type inference and Relational type inference has mismatched info (regarding filterArg existent, thus nullability mismatched",
+        "sql": "SELECT bool_col, COALESCE(min(double_col) FILTER (WHERE string_col = 'a' OR string_col = 'b'), 0), COALESCE(max(double_col) FILTER (WHERE string_col = 'a' OR int_col2 > 10), 0), avg(double_col), sum(double_col), count(double_col), count(distinct(double_col)) FILTER (WHERE string_col = 'b' OR int_col2 > 10), count(string_col) FROM {tbl1} JOIN {tbl2} ON string_col = string_col2 WHERE string_col='b' GROUP BY bool_col"
+      }
+    ]
+  }
+}

