Posted to commits@pinot.apache.org by ro...@apache.org on 2023/07/11 02:49:06 UTC

[pinot] branch master updated: [multistage] Fix aggregate type issues (#11068)

This is an automated email from the ASF dual-hosted git repository.

rongr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 19dce0e282 [multistage] Fix aggregate type issues (#11068)
19dce0e282 is described below

commit 19dce0e282ccd40f3da0fa2c8168efbd812f60fd
Author: Rong Rong <ro...@apache.org>
AuthorDate: Mon Jul 10 19:49:00 2023 -0700

    [multistage] Fix aggregate type issues (#11068)
    
    1. cleaned up type issues for agg
        - Pinot SQL agg functions are no longer pre-registered with Calcite, which caused type issues during SqlValidator
        - dynamic Pinot SQL agg function creation is used instead, to avoid type inference that cannot be applied statically (e.g. inferring from ARG0, because ARG0 would then refer to the input in its intermediate type, not the original DIRECT type); see the function-creation sketch after this list
    2. revisited intermediate type inference when registering built-in aggs (such as MIN/MAX), instead of raw-casting from number/boolean to double/int when extracting values from input blocks
        - follow-up: register explicit intermediate result types
        - also follow up on cases where `ReturnTypes.explicit(*)` is not possible for the final result type, by working around it with the validated type available from the RelBuilder rule call object
    3. fixed boolean handling (except bool_and/bool_or, which are handled separately in #11033)
    4. re-adjusted the direct project rule by using the Calcite AGG_PROJECT_EXPAND template
        - follow-up needed to ensure the direct aggregate (see TODO #1, #2)
    5. re-adjusted the agg-exchange rule
        - no longer generates multiple aggs; now we only generate DIRECT/LEAF/FINAL (INTERMEDIATE is not supported yet); see the chain sketch after this list
        - un-generalized the rule to NOT reuse convert-agg-node and build-agg-call; they don't share logic, so the generalization caused more problems with various nullable fields
        - removed unnecessary logic for discovering field-type/field-name/field-used, which already exists in Calcite's RelOptUtil
    6. modified the generated test plans, removing unnecessary additional aggs
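
    Below is a minimal sketch of the stage-specific agg function creation from items 1-2, mirroring the buildAggregateCall change in the diff; functionName, functionType, and orgAggCall stand in for the values available inside the rule (OperandTypes/ReturnTypes are Calcite's org.apache.calcite.sql.type helpers):

        // LEAF: infer the intermediate object type, so the exchange carries it as-is
        SqlAggFunction leafFn = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
            functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null,
            functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
        // FINAL: pin the validated type of the original call instead of re-running inference,
        // because ARG0 would now resolve against the intermediate-typed input
        SqlAggFunction finalFn = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
            functionType.getSqlKind(), ReturnTypes.explicit(orgAggCall.getType()), null,
            OperandTypes.ANY, functionType.getSqlFunctionCategory());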
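
    And a sketch of the LEAF + exchange + FINAL chain from item 5, also mirroring the diff: hash-distribute the leaf's intermediate results on the group keys, then re-aggregate on top of the exchange:

        Aggregate leafAgg = convertAggForLeafInput(oldAggRel);    // `agg`__LEAF, runs on servers
        PinotLogicalExchange exchange = PinotLogicalExchange.create(leafAgg,
            RelDistributions.hash(ImmutableIntList.range(0, oldAggRel.getGroupCount())));
        // `agg`__FINAL merges the intermediate results after the shuffle
        RelNode finalAgg = convertAggFromIntermediateInput(call, oldAggRel, exchange, AggType.FINAL);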
    
    ---------
    
    Co-authored-by: Rong Rong <ro...@startree.ai>
---
 .../PinotAggregateExchangeNodeInsertRule.java      | 464 ++++++++-------------
 .../rules/PinotAggregateReduceFunctionsRule.java   | 427 -------------------
 .../calcite/rel/rules/PinotQueryRuleSets.java      |   2 +-
 .../apache/calcite/sql/fun/PinotOperatorTable.java | 140 +------
 .../apache/pinot/query/QueryCompilationTest.java   |   1 -
 .../pinot/query/QueryEnvironmentTestBase.java      |   5 +-
 .../src/test/resources/queries/AggregatePlans.json | 183 ++------
 .../src/test/resources/queries/GroupByPlans.json   | 308 ++++----------
 .../src/test/resources/queries/JoinPlans.json      |  77 ++--
 .../src/test/resources/queries/OrderByPlans.json   |  32 +-
 .../test/resources/queries/PinotHintablePlans.json |  64 +--
 .../resources/queries/WindowFunctionPlans.json     | 228 ++++------
 .../query/runtime/operator/AggregateOperator.java  |  47 ++-
 .../query/runtime/operator/FilterOperator.java     |   4 +-
 .../query/runtime/operator/HashJoinOperator.java   |   4 +-
 .../LeafStageTransferableBlockOperator.java        |  44 +-
 .../operator/MultistageAggregationExecutor.java    | 109 ++---
 .../operator/MultistageGroupByExecutor.java        | 128 ++----
 .../query/runtime/operator/TransformOperator.java  |   5 +-
 .../runtime/operator/operands/FilterOperand.java   |   6 +-
 .../operator/utils/FunctionInvokeUtils.java        |  45 --
 .../query/runtime/operator/utils/TypeUtils.java    |  88 ++++
 .../runtime/operator/AggregateOperatorTest.java    |  40 +-
 .../pinot/segment/spi/AggregationFunctionType.java | 186 ++++-----
 24 files changed, 786 insertions(+), 1851 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
index 199619d14b..7792a0b4e0 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateExchangeNodeInsertRule.java
@@ -26,7 +26,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
-import javax.annotation.Nullable;
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
 import org.apache.calcite.plan.hep.HepRelVertex;
@@ -39,49 +38,45 @@ import org.apache.calcite.rel.hint.PinotHintOptions;
 import org.apache.calcite.rel.hint.PinotHintStrategyTable;
 import org.apache.calcite.rel.hint.RelHint;
 import org.apache.calcite.rel.logical.LogicalAggregate;
-import org.apache.calcite.rel.logical.LogicalProject;
 import org.apache.calcite.rel.logical.PinotLogicalExchange;
-import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.sql.PinotSqlAggFunction;
 import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.fun.PinotOperatorTable;
-import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.type.OperandTypes;
+import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.ImmutableIntList;
+import org.apache.calcite.util.Util;
+import org.apache.calcite.util.mapping.Mapping;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
 import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
 
 
 /**
- * Special rule for Pinot, this rule is fixed to generate a 3-stage aggregation split between the
- * (1) non-data-locale Pinot server agg stage, (2) the data-locale Pinot intermediate agg stage, and
- * (3) final result Pinot final agg stage.
+ * Special rule for Pinot: this rule is fixed to generate a 2-stage aggregation split between the
+ * (1) non-data-locale Pinot server agg stage, and (2) the data-locale Pinot intermediate agg stage.
  *
  * Pinot uses special intermediate data representation for partially aggregated results, thus we can't use
  * {@link org.apache.calcite.rel.rules.AggregateReduceFunctionsRule} to reduce complex aggregation.
  *
- * This rule is here to introduces Pinot-special aggregation splits. In-general, all aggregations are split into
- * final-stage AGG, intermediate-stage AGG, and server-stage AGG with the same naming. E.g.
+ * This rule is here to introduce Pinot-special aggregation splits. In general there are several options:
+ * <ul>
+ *   <li>`aggName`__DIRECT</li>
+ *   <li>`aggName`__LEAF + `aggName`__FINAL</li>
+ *   <li>`aggName`__LEAF [+ `aggName`__INTERMEDIATE] + `aggName`__FINAL</li>
+ * </ul>
  *
- * COUNT(*) transforms into: COUNT(*)_SERVER --> COUNT(*)_INTERMEDIATE --> COUNT(*)_FINAL, where
- *   COUNT(*)_SERVER produces TUPLE[ COUNT(data), GROUP_BY_KEY ]
- *   COUNT(*)_INTERMEDIATE produces TUPLE[ SUM(COUNT(*)_SERVER), GROUP_BY_KEY ] (intermediate result here is the count)
- *   COUNT(*)_FINAL produces the final TUPLE[ FINAL_COUNT, GROUP_BY_KEY ]
- *
- * Taking an example of a function which has a different intermediate object representation than the final result:
- * KURTOSIS(*) transforms into: 4THMOMENT(*)_SERVER --> 4THMOMENT(*)_INTERMEDIATE --> KURTOSIS(*)_FINAL, where
- *   FOURTHMOMENT(*)_SERVER produces TUPLE[ 4THMOMENT(data) object, GROUP_BY_KEY ] (input: rowType, output: object)
- *   FOURTHMOMENT(*)_INTERMEDIATE produces TUPLE[ 4THMOMENT(4THMOMENT(*)_SERVER), GROUP_BY_KEY ] (input, output: object)
- *   KURTOSIS(*)_FINAL produces the final TUPLE[ KURTOSIS(4THMOMENT(*)_INTERMEDIATE), GROUP_BY_KEY ]
- *     (input: object, output: double)
- *
- * However, the suffix _SERVER/_INTERMEDIATE/_FINAL is merely a SQL hint to the Aggregate operator and will be
- * translated into correct, actual operator chain during Physical plan.
+ * for example:
+ * - COUNT(*) with a GROUP_BY_KEY transforms into: COUNT(*)__LEAF --> COUNT(*)__FINAL, where
+ *   - COUNT(*)__LEAF produces TUPLE[ SUM(1), GROUP_BY_KEY ]
+ *   - COUNT(*)__FINAL produces TUPLE[ SUM(COUNT(*)__LEAF), GROUP_BY_KEY ]
  */
 public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
   public static final PinotAggregateExchangeNodeInsertRule INSTANCE =
@@ -117,336 +112,241 @@ public class PinotAggregateExchangeNodeInsertRule extends RelOptRule {
     Aggregate oldAggRel = call.rel(0);
     ImmutableList<RelHint> oldHints = oldAggRel.getHints();
 
-    // If the "is_partitioned_by_group_by_keys" aggregate hint option is set, just add additional hints indicating
-    // this is a single stage aggregation and intermediate stage. This only applies to GROUP BY aggregations.
+    Aggregate newAgg;
     if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
         PinotHintOptions.AGGREGATE_HINT_OPTIONS, PinotHintOptions.AggregateOptions.IS_PARTITIONED_BY_GROUP_BY_KEYS)) {
-      // 1. attach intermediate agg and skip leaf stage RelHints to original agg.
+      // ------------------------------------------------------------------------
+      // If the "is_partitioned_by_group_by_keys" aggregate hint option is set, just add additional hints indicating
+      // this is a single stage aggregation.
       ImmutableList<RelHint> newLeafAggHints =
           new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.DIRECT)).build();
-      Aggregate newAgg =
+      newAgg =
           new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newLeafAggHints, oldAggRel.getInput(),
               oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList());
-      call.transformTo(newAgg);
-      return;
-    }
-
-    // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf stage aggregation is skipped.
-    if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
+    } else if (!oldAggRel.getGroupSet().isEmpty() && PinotHintStrategyTable.containsHintOption(oldHints,
         PinotHintOptions.AGGREGATE_HINT_OPTIONS,
         PinotHintOptions.AggregateOptions.SKIP_LEAF_STAGE_GROUP_BY_AGGREGATION)) {
-      // This is not the default path. Use this group by optimization to skip leaf stage aggregation when aggregating
-      // at leaf level could be wasted effort. eg: when cardinality of group by column is very high.
-      Aggregate newAgg = (Aggregate) createPlanWithoutLeafAggregation(call);
-      call.transformTo(newAgg);
-      return;
+      // ------------------------------------------------------------------------
+      // If "is_skip_leaf_stage_group_by" SQLHint option is passed, the leaf stage aggregation is skipped.
+      newAgg = (Aggregate) createPlanWithExchangeDirectAggregation(call);
+    } else {
+      // ------------------------------------------------------------------------
+      newAgg = (Aggregate) createPlanWithLeafExchangeFinalAggregate(call);
     }
+    call.transformTo(newAgg);
+  }
 
+  /**
+   * Aggregate node will be split into LEAF + exchange + FINAL.
+   * Optionally, an INTERMEDIATE stage can be inserted in the future to reduce hotspots.
+   */
+  private RelNode createPlanWithLeafExchangeFinalAggregate(RelOptRuleCall call) {
+    // TODO: add optional intermediate stage here when hinted.
+    Aggregate oldAggRel = call.rel(0);
     // 1. attach leaf agg RelHint to original agg. Perform any aggregation call conversions necessary
-    ImmutableList<RelHint> newLeafAggHints =
-        new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.LEAF)).build();
-    Aggregate newLeafAgg =
-        new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newLeafAggHints, oldAggRel.getInput(),
-            oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), convertLeafAggCalls(oldAggRel));
-
+    Aggregate leafAgg = convertAggForLeafInput(oldAggRel);
     // 2. attach exchange.
     List<Integer> groupSetIndices = ImmutableIntList.range(0, oldAggRel.getGroupCount());
     PinotLogicalExchange exchange;
     if (groupSetIndices.size() == 0) {
-      exchange = PinotLogicalExchange.create(newLeafAgg, RelDistributions.hash(Collections.emptyList()));
+      exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(Collections.emptyList()));
     } else {
-      exchange = PinotLogicalExchange.create(newLeafAgg, RelDistributions.hash(groupSetIndices));
+      exchange = PinotLogicalExchange.create(leafAgg, RelDistributions.hash(groupSetIndices));
     }
+    // 3. attach final agg stage.
+    return convertAggFromIntermediateInput(call, oldAggRel, exchange, AggType.FINAL);
+  }
 
-    // 3. attach intermediate agg stage.
-    RelNode newAggNode = makeNewIntermediateAgg(call, oldAggRel, exchange, AggType.INTERMEDIATE, null, null);
+  /**
+   * Use this group-by optimization to skip leaf stage aggregation when aggregating at the leaf level is not desired.
+   * In many situations group-by on the leaf is wasted effort, e.g. when the cardinality of the group-by column is very high.
+   */
+  private RelNode createPlanWithExchangeDirectAggregation(RelOptRuleCall call) {
+    Aggregate oldAggRel = call.rel(0);
+    ImmutableList<RelHint> oldHints = oldAggRel.getHints();
+    ImmutableList<RelHint> newHints =
+        new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.DIRECT)).build();
 
-    // 4. attach final agg stage if aggregations are present.
-    RelNode transformToAgg = newAggNode;
-    if (oldAggRel.getAggCallList() != null && oldAggRel.getAggCallList().size() > 0) {
-      transformToAgg = makeNewFinalAgg(call, oldAggRel, newAggNode);
+    // create project when there's none below the aggregate to reduce exchange overhead
+    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
+    if (!(childRel instanceof Project)) {
+      return convertAggForExchangeDirectAggregate(call, newHints);
+    } else {
+      // create normal exchange
+      List<Integer> groupSetIndices = new ArrayList<>();
+      oldAggRel.getGroupSet().forEach(groupSetIndices::add);
+      PinotLogicalExchange exchange = PinotLogicalExchange.create(childRel, RelDistributions.hash(groupSetIndices));
+      return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, exchange,
+          oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), oldAggRel.getAggCallList());
     }
+  }
 
-    call.transformTo(transformToAgg);
+  /**
+   * The following is copied from {@link AggregateExtractProjectRule#onMatch(RelOptRuleCall)}
+   * with modification to insert an exchange in between the Aggregate and Project
+   */
+  private RelNode convertAggForExchangeDirectAggregate(RelOptRuleCall call, ImmutableList<RelHint> newHints) {
+    final Aggregate aggregate = call.rel(0);
+    final RelNode input = aggregate.getInput();
+    // Compute which input fields are used.
+    // 1. group fields are always used
+    final ImmutableBitSet.Builder inputFieldsUsed =
+        aggregate.getGroupSet().rebuild();
+    // 2. agg functions
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      for (int i : aggCall.getArgList()) {
+        inputFieldsUsed.set(i);
+      }
+      if (aggCall.filterArg >= 0) {
+        inputFieldsUsed.set(aggCall.filterArg);
+      }
+    }
+    final RelBuilder relBuilder1 = call.builder().push(input);
+    final List<RexNode> projects = new ArrayList<>();
+    final Mapping mapping =
+        Mappings.create(MappingType.INVERSE_SURJECTION,
+            aggregate.getInput().getRowType().getFieldCount(),
+            inputFieldsUsed.cardinality());
+    int j = 0;
+    for (int i : inputFieldsUsed.build()) {
+      projects.add(relBuilder1.field(i));
+      mapping.set(i, j++);
+    }
+    relBuilder1.project(projects);
+    final ImmutableBitSet newGroupSet =
+        Mappings.apply(mapping, aggregate.getGroupSet());
+    Project project = (Project) relBuilder1.build();
+
+    // ------------------------------------------------------------------------
+    PinotLogicalExchange exchange = PinotLogicalExchange.create(project, RelDistributions.hash(newGroupSet.asList()));
+    // ------------------------------------------------------------------------
+
+    final RelBuilder relBuilder2 = call.builder().push(exchange);
+    final List<ImmutableBitSet> newGroupSets =
+        aggregate.getGroupSets().stream()
+            .map(bitSet -> Mappings.apply(mapping, bitSet))
+            .collect(Util.toImmutableList());
+    final List<RelBuilder.AggCall> newAggCallList =
+        aggregate.getAggCallList().stream()
+            .map(aggCall -> relBuilder2.aggregateCall(aggCall, mapping))
+            .collect(Util.toImmutableList());
+    final RelBuilder.GroupKey groupKey =
+        relBuilder2.groupKey(newGroupSet, newGroupSets);
+    relBuilder2.aggregate(groupKey, newAggCallList).hints(newHints);
+    return relBuilder2.build();
   }
 
-  private List<AggregateCall> convertLeafAggCalls(Aggregate oldAggRel) {
+  private Aggregate convertAggForLeafInput(Aggregate oldAggRel) {
     List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
     List<AggregateCall> newCalls = new ArrayList<>();
     for (AggregateCall oldCall : oldCalls) {
-      newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall.getArgList(), oldAggRel.getGroupCount(), oldCall,
-          oldCall, false));
+      newCalls.add(buildAggregateCall(oldAggRel.getInput(), oldCall, oldCall.getArgList(), oldAggRel.getGroupCount(),
+          AggType.LEAF));
     }
-    return newCalls;
+    ImmutableList<RelHint> oldHints = oldAggRel.getHints();
+    ImmutableList<RelHint> newHints =
+        new ImmutableList.Builder<RelHint>().addAll(oldHints).add(createAggHint(AggType.LEAF)).build();
+    return new LogicalAggregate(oldAggRel.getCluster(), oldAggRel.getTraitSet(), newHints, oldAggRel.getInput(),
+        oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), newCalls);
   }
 
-  private RelNode makeNewIntermediateAgg(RelOptRuleCall ruleCall, Aggregate oldAggRel, PinotLogicalExchange exchange,
-      AggType aggType, @Nullable List<Integer> argList, @Nullable List<Integer> groupByList) {
-
+  private RelNode convertAggFromIntermediateInput(RelOptRuleCall ruleCall, Aggregate oldAggRel,
+      PinotLogicalExchange exchange, AggType aggType) {
     // add the exchange as the input node to the relation builder.
     RelBuilder relBuilder = ruleCall.builder();
     relBuilder.push(exchange);
 
-    // make input ref to the exchange after the leaf aggregate.
+    // make input ref to the exchange after the leaf aggregate, all groups should be at the front
     RexBuilder rexBuilder = exchange.getCluster().getRexBuilder();
     final int nGroups = oldAggRel.getGroupCount();
     for (int i = 0; i < nGroups; i++) {
       rexBuilder.makeInputRef(oldAggRel, i);
     }
 
-    // create new aggregate function calls from exchange input.
-    List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
     List<AggregateCall> newCalls = new ArrayList<>();
     Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
 
-    for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) {
-      AggregateCall oldCall = oldCalls.get(oldCallIndex);
-      convertIntermediateAggCall(rexBuilder, oldAggRel, oldCallIndex, oldCall, newCalls, aggCallMapping,
-          aggType, argList, exchange);
-    }
-
-    // create new aggregate relation.
-    ImmutableList<RelHint> orgHints = oldAggRel.getHints();
-    // if the aggregation isn't split between intermediate and final stages, indicate that this is a single stage
-    // aggregation so that the execution engine knows whether to aggregate or merge
-    ImmutableList<RelHint> newIntermediateAggHints =
-        new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(aggType)).build();
-    ImmutableBitSet groupSet = groupByList == null ? ImmutableBitSet.range(nGroups) : ImmutableBitSet.of(groupByList);
-    relBuilder.aggregate(
-        relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)),
-        newCalls);
-    relBuilder.hints(newIntermediateAggHints);
-    return relBuilder.build();
-  }
-
-  /**
-   * convert aggregate call based on the intermediate stage input.
-   *
-   * <p>Note that the intermediate stage input only supports splittable aggregators such as SUM/MIN/MAX.
-   * All non-splittable aggregator must be converted into splittable aggregator first.
-   */
-  private static void convertIntermediateAggCall(RexBuilder rexBuilder, Aggregate oldAggRel, int oldCallIndex,
-      AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
-      AggType aggType, List<Integer> argList, PinotLogicalExchange exchange) {
-    final int nGroups = oldAggRel.getGroupCount();
-    final SqlAggFunction oldAggregation = oldCall.getAggregation();
-
-    List<Integer> newArgList;
-    if (aggType.isInputIntermediateFormat()) {
-      // Make sure COUNT in the intermediate stage takes an argument
-      List<Integer> oldArgList = (oldAggregation.getKind() == SqlKind.COUNT && !oldCall.isDistinct())
-          ? Collections.singletonList(oldCallIndex)
-          : oldCall.getArgList();
-      newArgList = convertArgList(nGroups + oldCallIndex, oldArgList);
-    } else {
-      newArgList = oldCall.getArgList().size() == 0 ? Collections.emptyList()
-          : Collections.singletonList(argList.get(oldCallIndex));
-    }
-    AggregateCall newCall = buildAggregateCall(exchange, newArgList, nGroups, oldCall, oldCall, false);
-    rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable);
-  }
-
-  private RelNode makeNewFinalAgg(RelOptRuleCall ruleCall, Aggregate oldAggRel, RelNode newIntAggNode) {
-    // add the intermediate agg node as the input node to the relation builder.
-    RelBuilder relBuilder = ruleCall.builder();
-    relBuilder.push(newIntAggNode);
-
-    Aggregate aggIntNode = (Aggregate) newIntAggNode;
-
-    // make input ref to the intermediate agg node after the leaf aggregate.
-    RexBuilder rexBuilder = newIntAggNode.getCluster().getRexBuilder();
-    final int nGroups = aggIntNode.getGroupCount();
-    for (int i = 0; i < nGroups; i++) {
-      rexBuilder.makeInputRef(aggIntNode, i);
-    }
-
-    // create new aggregate function calls from intermediate agg input.
+    // create new aggregate function calls from the exchange input; the aggCalls follow the group keys one by one.
+    // Because the exchange produces intermediate results, the input to the newCall is indexed at
+    // [nGroups + oldCallIndex].
     List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
-    List<AggregateCall> oldCallsIntAgg = aggIntNode.getAggCallList();
-    List<AggregateCall> newCalls = new ArrayList<>();
-    Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
-
     for (int oldCallIndex = 0; oldCallIndex < oldCalls.size(); oldCallIndex++) {
       AggregateCall oldCall = oldCalls.get(oldCallIndex);
-      AggregateCall oldCallIntAgg = oldCallsIntAgg.get(oldCallIndex);
-      convertFinalAggCall(rexBuilder, aggIntNode, oldCallIndex, oldCall, oldCallIntAgg, newCalls, aggCallMapping);
+      // intermediate stage input only supports single argument inputs.
+      List<Integer> argList = Collections.singletonList(nGroups + oldCallIndex);
+      AggregateCall newCall = buildAggregateCall(exchange, oldCall, argList, nGroups, aggType);
+      rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, oldAggRel.getInput()::fieldIsNullable);
     }
 
     // create new aggregate relation.
     ImmutableList<RelHint> orgHints = oldAggRel.getHints();
-    ImmutableList<RelHint> newIntermediateAggHints =
-        new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(AggType.FINAL)).build();
+    ImmutableList<RelHint> newAggHint =
+        new ImmutableList.Builder<RelHint>().addAll(orgHints).add(createAggHint(aggType)).build();
     ImmutableBitSet groupSet = ImmutableBitSet.range(nGroups);
-    relBuilder.aggregate(
-        relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)),
-        newCalls);
-    relBuilder.hints(newIntermediateAggHints);
+    relBuilder.aggregate(relBuilder.groupKey(groupSet, ImmutableList.of(groupSet)), newCalls);
+    relBuilder.hints(newAggHint);
     return relBuilder.build();
   }
 
-  /**
-   * convert aggregate call based on the final stage input.
-   */
-  private static void convertFinalAggCall(RexBuilder rexBuilder, Aggregate inputAggRel, int oldCallIndex,
-      AggregateCall oldCall, AggregateCall oldCallIntAgg, List<AggregateCall> newCalls, Map<AggregateCall,
-      RexNode> aggCallMapping) {
-    final int nGroups = inputAggRel.getGroupCount();
-    final SqlAggFunction oldAggregation = oldCallIntAgg.getAggregation();
-    // Make sure COUNT in the final stage takes an argument
-    List<Integer> oldArgList = (oldAggregation.getKind() == SqlKind.COUNT && !oldCallIntAgg.isDistinct())
-        ? Collections.singletonList(oldCallIndex)
-        : oldCallIntAgg.getArgList();
-    List<Integer> newArgList = convertArgList(nGroups + oldCallIndex, oldArgList);
-    AggregateCall newCall = buildAggregateCall(inputAggRel, newArgList, nGroups, oldCallIntAgg, oldCall, true);
-    rexBuilder.addAggCall(newCall, nGroups, newCalls, aggCallMapping, inputAggRel.getInput()::fieldIsNullable);
-  }
-
-  private static AggregateCall buildAggregateCall(RelNode input, List<Integer> newArgList, int numberGroups,
-      AggregateCall inputCall, AggregateCall functionNameCall, boolean isFinalStage) {
-    final SqlAggFunction oldAggregation = functionNameCall.getAggregation();
-    final SqlKind aggKind = oldAggregation.getKind();
-    String functionName = getFunctionNameFromAggregateCall(functionNameCall);
+  private static AggregateCall buildAggregateCall(RelNode inputNode, AggregateCall orgAggCall, List<Integer> argList,
+      int numberGroups, AggType aggType) {
+    final SqlAggFunction oldAggFunction = orgAggCall.getAggregation();
+    final SqlKind aggKind = oldAggFunction.getKind();
+    String functionName = getFunctionNameFromAggregateCall(orgAggCall);
+    AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName);
     // Check only the supported AGG functions are provided.
-    validateAggregationFunctionIsSupported(functionName, aggKind);
-
-    AggregationFunctionType type = AggregationFunctionType.getAggregationFunctionType(functionName);
-    // Use the actual function name and return type for final stage to ensure that for aggregation functions that share
-    // leaf and intermediate functions, we can correctly extract the correct final result. e.g. KURTOSIS and SKEWNESS
-    // both use FOURTHMOMENT
-    String aggregationFunctionName = isFinalStage ? type.getName().toUpperCase(Locale.ROOT)
-        : type.getIntermediateFunctionName().toUpperCase(Locale.ROOT);
-    SqlReturnTypeInference returnTypeInference = isFinalStage ? type.getSqlReturnTypeInference()
-        : type.getSqlIntermediateReturnTypeInference();
-    SqlAggFunction sqlAggFunction =
-        new PinotSqlAggFunction(aggregationFunctionName, type.getSqlIdentifier(), type.getSqlKind(),
-            returnTypeInference, type.getSqlOperandTypeInference(), type.getSqlOperandTypeChecker(),
-            type.getSqlFunctionCategory());
+    validateAggregationFunctionIsSupported(functionType.getName(), aggKind);
+    // create the aggFunction
+    SqlAggFunction sqlAggFunction;
+    if (functionType.getIntermediateReturnTypeInference() != null) {
+      switch (aggType) {
+        case LEAF:
+          sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+              functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null,
+              functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
+          break;
+        case INTERMEDIATE:
+          sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+              functionType.getSqlKind(), functionType.getIntermediateReturnTypeInference(), null,
+              OperandTypes.ANY, functionType.getSqlFunctionCategory());
+          break;
+        case FINAL:
+          sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+              functionType.getSqlKind(), ReturnTypes.explicit(orgAggCall.getType()), null,
+              OperandTypes.ANY, functionType.getSqlFunctionCategory());
+          break;
+        default:
+          throw new UnsupportedOperationException("Unsupported aggType: " + aggType + " for " + functionName);
+      }
+    } else {
+      sqlAggFunction = oldAggFunction;
+    }
 
     return AggregateCall.create(sqlAggFunction,
-        functionName.equals("distinctCount") || inputCall.isDistinct(),
-        inputCall.isApproximate(),
-        inputCall.ignoreNulls(),
-        newArgList,
-        inputCall.filterArg,
-        inputCall.distinctKeys,
-        inputCall.collation,
+        functionName.equals("distinctCount") || orgAggCall.isDistinct(),
+        orgAggCall.isApproximate(),
+        orgAggCall.ignoreNulls(),
+        argList,
+        orgAggCall.filterArg,
+        orgAggCall.distinctKeys,
+        orgAggCall.collation,
         numberGroups,
-        input,
+        inputNode,
         null,
         null);
   }
 
-  private static List<Integer> convertArgList(int oldCallIndexWithShift, List<Integer> argList) {
-    Preconditions.checkArgument(argList.size() <= 1,
-        "Unable to convert call as the argList contains more than 1 argument");
-    return argList.size() == 1 ? Collections.singletonList(oldCallIndexWithShift) : Collections.emptyList();
-  }
-
   private static String getFunctionNameFromAggregateCall(AggregateCall aggregateCall) {
     return aggregateCall.getAggregation().getName().equalsIgnoreCase("COUNT") && aggregateCall.isDistinct()
         ? "distinctCount" : aggregateCall.getAggregation().getName();
   }
 
   private static void validateAggregationFunctionIsSupported(String functionName, SqlKind aggKind) {
-    Preconditions.checkState(PinotOperatorTable.isAggregationFunctionRegisteredWithOperatorTable(functionName)
-            || PinotOperatorTable.isAggregationKindSupported(aggKind),
+    Preconditions.checkState(PinotOperatorTable.isAggregationFunctionRegisteredWithOperatorTable(functionName),
         String.format("Failed to create aggregation. Unsupported SQL aggregation kind: %s or function name: %s. "
                 + "Only splittable aggregation functions are supported!", aggKind, functionName));
   }
 
-  private RelNode createPlanWithoutLeafAggregation(RelOptRuleCall call) {
-    Aggregate oldAggRel = call.rel(0);
-    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
-    LogicalProject project;
-
-    List<Integer> newAggArgColumns = new ArrayList<>();
-    List<Integer> newAggGroupByColumns = new ArrayList<>();
-
-    // 1. Create the LogicalProject node if it does not exist. This is to send only the relevant columns over
-    //    the wire for intermediate aggregation.
-    if (childRel instanceof Project) {
-      // Avoid creating a new LogicalProject if the child node of aggregation is already a project node.
-      project = (LogicalProject) childRel;
-      newAggArgColumns = fetchNewAggArgCols(oldAggRel.getAggCallList());
-      newAggGroupByColumns = oldAggRel.getGroupSet().asList();
-    } else {
-      // Create a leaf stage project. This is done so that only the required columns are sent over the wire for
-      // intermediate aggregation. If there are multiple aggregations on the same column, the column is projected
-      // only once.
-      project = createLogicalProjectForAggregate(oldAggRel, newAggArgColumns, newAggGroupByColumns);
-    }
-
-    // 2. Create an exchange on top of the LogicalProject.
-    PinotLogicalExchange exchange = PinotLogicalExchange.create(project, RelDistributions.hash(newAggGroupByColumns));
-
-    // 3. Create an intermediate stage aggregation.
-    RelNode newAggNode =
-        makeNewIntermediateAgg(call, oldAggRel, exchange, AggType.LEAF, newAggArgColumns, newAggGroupByColumns);
-
-    // 4. Create the final agg stage node on top of the intermediate agg if aggregations are present.
-    RelNode transformToAgg = newAggNode;
-    if (oldAggRel.getAggCallList() != null && oldAggRel.getAggCallList().size() > 0) {
-      transformToAgg = makeNewFinalAgg(call, oldAggRel, newAggNode);
-    }
-    return transformToAgg;
-  }
-
-  private LogicalProject createLogicalProjectForAggregate(Aggregate oldAggRel, List<Integer> newAggArgColumns,
-      List<Integer> newAggGroupByCols) {
-    RelNode childRel = ((HepRelVertex) oldAggRel.getInput()).getCurrentRel();
-    RexBuilder childRexBuilder = childRel.getCluster().getRexBuilder();
-    List<RelDataTypeField> fieldList = childRel.getRowType().getFieldList();
-
-    List<RexNode> projectColRexNodes = new ArrayList<>();
-    List<String> projectColNames = new ArrayList<>();
-    // Maintains a mapping from the column to the corresponding index in projectColRexNodes.
-    Map<Integer, Integer> projectSet = new HashMap<>();
-
-    int projectIndex = 0;
-    for (int group : oldAggRel.getGroupSet().asSet()) {
-      projectColNames.add(fieldList.get(group).getName());
-      projectColRexNodes.add(childRexBuilder.makeInputRef(childRel, group));
-      projectSet.put(group, projectColRexNodes.size() - 1);
-      newAggGroupByCols.add(projectIndex++);
-    }
-
-    List<AggregateCall> oldAggCallList = oldAggRel.getAggCallList();
-    for (AggregateCall aggregateCall : oldAggCallList) {
-      List<Integer> argList = aggregateCall.getArgList();
-      if (argList.size() == 0) {
-        newAggArgColumns.add(-1);
-        continue;
-      }
-      for (Integer col : argList) {
-        if (!projectSet.containsKey(col)) {
-          projectColRexNodes.add(childRexBuilder.makeInputRef(childRel, col));
-          projectColNames.add(fieldList.get(col).getName());
-          projectSet.put(col, projectColRexNodes.size() - 1);
-          newAggArgColumns.add(projectColRexNodes.size() - 1);
-        } else {
-          newAggArgColumns.add(projectSet.get(col));
-        }
-      }
-    }
-
-    return LogicalProject.create(childRel, Collections.emptyList(), projectColRexNodes, projectColNames);
-  }
-
-  private List<Integer> fetchNewAggArgCols(List<AggregateCall> oldAggCallList) {
-    List<Integer> newAggArgColumns = new ArrayList<>();
-
-    for (AggregateCall aggregateCall : oldAggCallList) {
-      if (aggregateCall.getArgList().size() == 0) {
-        // This can be true for COUNT. Add a placeholder value which will be ignored.
-        newAggArgColumns.add(-1);
-        continue;
-      }
-      newAggArgColumns.addAll(aggregateCall.getArgList());
-    }
-
-    return newAggArgColumns;
-  }
-
   private static RelHint createAggHint(AggType aggType) {
     return RelHint.builder(PinotHintOptions.INTERNAL_AGG_OPTIONS)
         .hintOption(PinotHintOptions.InternalAggregateOptions.AGG_TYPE, aggType.name())
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
deleted file mode 100644
index a0d9634667..0000000000
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotAggregateReduceFunctionsRule.java
+++ /dev/null
@@ -1,427 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.calcite.rel.rules;
-
-import com.google.common.collect.ImmutableSet;
-import java.math.BigDecimal;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Set;
-import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.rel.core.Aggregate;
-import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.sql.PinotSqlAggFunction;
-import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.tools.RelBuilder;
-import org.apache.calcite.tools.RelBuilderFactory;
-import org.apache.calcite.util.CompositeList;
-import org.apache.calcite.util.Util;
-import org.apache.pinot.segment.spi.AggregationFunctionType;
-
-
-/**
- * Note: This class copies the logic for reducing SUM and AVG from {@link AggregateReduceFunctionsRule} with some
- * changes to use our Pinot defined operand type checkers and return types. This is necessary otherwise the v1
- * aggregations won't work in the v2 engine due to type issues. Once we fix the return types for the v1 aggregation
- * functions the logic for AVG and SUM can be removed. We also had to resort to using an AVG_REDUCE scalar function
- * due to null handling issues with DIVIDE (returning null on count = 0 via a CASE statement was also not possible
- * as the types of the columns were all non-null and Calcite marks nullable and non-nullable columns as incompatible).
- *
- * We added additional logic to handle typecasting MIN / MAX functions for EVERY / SOME aggregation functions in Calcite
- * which internally uses MIN / MAX with boolean return types. This was necessary because the v1 aggregations for
- * MIN / MAX always return DOUBLE and this caused type issues for certain queries that utilize Calcite's EVERY / SOME
- * aggregation functions.
- *
- * Planner rule that reduces aggregate functions in
- * {@link org.apache.calcite.rel.core.Aggregate}s to simpler forms.
- *
- * <p>Rewrites:
- * <ul>
- *
- * <li>AVG(x) &rarr; SUM(x) / COUNT(x)
- *
- * </ul>
- *
- * <p>Since many of these rewrites introduce multiple occurrences of simpler
- * forms like {@code COUNT(x)}, the rule gathers common sub-expressions as it
- * goes.
- *
- * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS
- */
-public class PinotAggregateReduceFunctionsRule
-    extends RelOptRule {
-
-  public static final PinotAggregateReduceFunctionsRule INSTANCE =
-      new PinotAggregateReduceFunctionsRule(PinotRuleUtils.PINOT_REL_FACTORY);
-  //~ Static fields/initializers ---------------------------------------------
-
-  protected PinotAggregateReduceFunctionsRule(RelBuilderFactory factory) {
-    super(operand(Aggregate.class, any()), factory, null);
-  }
-
-  private final Set<SqlKind> _functionsToReduce = ImmutableSet.<SqlKind>builder().addAll(SqlKind.AVG_AGG_FUNCTIONS)
-      .add(SqlKind.SUM).add(SqlKind.MAX).add(SqlKind.MIN).build();
-
-  //~ Constructors -----------------------------------------------------------
-
-
-
-  //~ Methods ----------------------------------------------------------------
-
-  @Override public boolean matches(RelOptRuleCall call) {
-    if (!super.matches(call)) {
-      return false;
-    }
-    Aggregate oldAggRel = (Aggregate) call.rels[0];
-    return containsAvgStddevVarCall(oldAggRel.getAggCallList());
-  }
-
-  @Override public void onMatch(RelOptRuleCall ruleCall) {
-    Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
-    reduceAggs(ruleCall, oldAggRel);
-  }
-
-  /**
-   * Returns whether any of the aggregates are calls to AVG, STDDEV_*, VAR_*.
-   *
-   * @param aggCallList List of aggregate calls
-   */
-  private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
-    return aggCallList.stream().anyMatch(this::canReduce);
-  }
-
-  /** Returns whether this rule can reduce a given aggregate function call. */
-  public boolean canReduce(AggregateCall call) {
-    return _functionsToReduce.contains(call.getAggregation().getKind());
-  }
-
-  /**
-   * Reduces calls to functions AVG, SUM, MIN, MAX if the function is
-   * present in {@link PinotAggregateReduceFunctionsRule#_functionsToReduce}
-   *
-   * <p>It handles newly generated common subexpressions since this was done
-   * at the sql2rel stage.
-   */
-  private void reduceAggs(
-      RelOptRuleCall ruleCall,
-      Aggregate oldAggRel) {
-    RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
-    List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
-    final int groupCount = oldAggRel.getGroupCount();
-
-    final List<AggregateCall> newCalls = new ArrayList<>();
-    final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
-
-    final List<RexNode> projList = new ArrayList<>();
-
-    // pass through group key
-    for (int i = 0; i < groupCount; i++) {
-      projList.add(rexBuilder.makeInputRef(oldAggRel, i));
-    }
-
-    // List of input expressions. If a particular aggregate needs more, it
-    // will add an expression to the end, and we will create an extra
-    // project.
-    final RelBuilder relBuilder = ruleCall.builder();
-    relBuilder.push(oldAggRel.getInput());
-    final List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());
-
-    // create new aggregate function calls and rest of project list together
-    for (AggregateCall oldCall : oldCalls) {
-      projList.add(
-          reduceAgg(
-              oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
-    }
-
-    final int extraArgCount =
-        inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
-    if (extraArgCount > 0) {
-      relBuilder.project(inputExprs,
-          CompositeList.of(
-              relBuilder.peek().getRowType().getFieldNames(),
-              Collections.nCopies(extraArgCount, null)));
-    }
-    newAggregateRel(relBuilder, oldAggRel, newCalls);
-    newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
-    ruleCall.transformTo(relBuilder.build());
-  }
-
-  private RexNode reduceAgg(
-      Aggregate oldAggRel,
-      AggregateCall oldCall,
-      List<AggregateCall> newCalls,
-      Map<AggregateCall, RexNode> aggCallMapping,
-      List<RexNode> inputExprs) {
-    if (canReduce(oldCall)) {
-      final Integer y;
-      final Integer x;
-      final SqlKind kind = oldCall.getAggregation().getKind();
-      switch (kind) {
-        case SUM:
-          // replace original SUM(x) with
-          // case COUNT(x) when 0 then null else SUM0(x) end
-          return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
-        case AVG:
-          // replace original AVG(x) with SUM(x) / COUNT(x) via an AVG_REDUCE scalar function for null handling
-          return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
-        case MIN:
-        case MAX:
-          // typecast to oldCall type (BOOLEAN) if needed to handle EVERY / SOME aggregations. This is essentially a
-          // no-op typecasting for normal MIN / MAX aggregations.
-          return reduceMinMax(oldAggRel, oldCall, newCalls, aggCallMapping);
-        default:
-          throw Util.unexpected(kind);
-      }
-    } else {
-      // anything else:  preserve original call
-      RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-      final int nGroups = oldAggRel.getGroupCount();
-      return rexBuilder.addAggCall(oldCall,
-          nGroups,
-          newCalls,
-          aggCallMapping,
-          oldAggRel.getInput()::fieldIsNullable);
-    }
-  }
-
-  private static RexNode reduceAvg(
-      Aggregate oldAggRel,
-      AggregateCall oldCall,
-      List<AggregateCall> newCalls,
-      Map<AggregateCall, RexNode> aggCallMapping,
-      @SuppressWarnings("unused") List<RexNode> inputExprs) {
-    final int nGroups = oldAggRel.getGroupCount();
-    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
-    AggregationFunctionType functionTypeSum = AggregationFunctionType.SUM;
-    SqlAggFunction sumAggFunc = new PinotSqlAggFunction(functionTypeSum.getName().toUpperCase(),
-        functionTypeSum.getSqlIdentifier(), functionTypeSum.getSqlKind(), functionTypeSum.getSqlReturnTypeInference(),
-        functionTypeSum.getSqlOperandTypeInference(), functionTypeSum.getSqlOperandTypeChecker(),
-        functionTypeSum.getSqlFunctionCategory());
-
-    final AggregateCall sumCall =
-        AggregateCall.create(sumAggFunc,
-            oldCall.isDistinct(),
-            oldCall.isApproximate(),
-            oldCall.ignoreNulls(),
-            oldCall.getArgList(),
-            oldCall.filterArg,
-            oldCall.distinctKeys,
-            oldCall.collation,
-            oldAggRel.getGroupCount(),
-            oldAggRel.getInput(),
-            null,
-            null);
-    final AggregateCall countCall =
-        AggregateCall.create(SqlStdOperatorTable.COUNT,
-            oldCall.isDistinct(),
-            oldCall.isApproximate(),
-            oldCall.ignoreNulls(),
-            oldCall.getArgList(),
-            oldCall.filterArg,
-            oldCall.distinctKeys,
-            oldCall.collation,
-            oldAggRel.getGroupCount(),
-            oldAggRel.getInput(),
-            null,
-            null);
-
-    // NOTE:  these references are with respect to the output
-    // of newAggRel
-    RexNode numeratorRef =
-        rexBuilder.addAggCall(sumCall,
-            nGroups,
-            newCalls,
-            aggCallMapping,
-            oldAggRel.getInput()::fieldIsNullable);
-    final RexNode denominatorRef =
-        rexBuilder.addAggCall(countCall,
-            nGroups,
-            newCalls,
-            aggCallMapping,
-            oldAggRel.getInput()::fieldIsNullable);
-
-    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
-    final RelDataType avgType = typeFactory.createTypeWithNullability(
-        oldCall.getType(), numeratorRef.getType().isNullable());
-    numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
-
-    // TODO: Find a way to correctly use the DIVIDE binary operator instead of reduce function with a COUNT = 0
-    //       check. Today this does not work due to all types being declared as "NOT NULL", but returning null violates
-    //       this
-    // Special casing AVG to use a scalar function for reducing the results.
-    AggregationFunctionType type =
-        AggregationFunctionType.getAggregationFunctionType(oldCall.getAggregation().getName());
-    SqlFunction function = new SqlFunction(type.getReduceFunctionName().toUpperCase(Locale.ROOT),
-        SqlKind.OTHER_FUNCTION, type.getSqlReduceReturnTypeInference(), null,
-        type.getSqlReduceOperandTypeChecker(), SqlFunctionCategory.USER_DEFINED_FUNCTION);
-    List<RexNode> functionArgs = Arrays.asList(numeratorRef, denominatorRef);
-
-    // Use our own reducer instead of divide for null/0 count handling
-    final RexNode reduceRef = rexBuilder.makeCall(function, functionArgs);
-    return rexBuilder.makeCast(oldCall.getType(), reduceRef);
-  }
-
-  private static RexNode reduceSum(
-      Aggregate oldAggRel,
-      AggregateCall oldCall,
-      List<AggregateCall> newCalls,
-      Map<AggregateCall, RexNode> aggCallMapping) {
-    final int nGroups = oldAggRel.getGroupCount();
-    RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-
-    AggregationFunctionType functionTypeSum = AggregationFunctionType.SUM0;
-    SqlAggFunction sumAggFunc = new PinotSqlAggFunction(functionTypeSum.getName().toUpperCase(),
-        functionTypeSum.getSqlIdentifier(), functionTypeSum.getSqlKind(), functionTypeSum.getSqlReturnTypeInference(),
-        functionTypeSum.getSqlOperandTypeInference(), functionTypeSum.getSqlOperandTypeChecker(),
-        functionTypeSum.getSqlFunctionCategory());
-
-    final AggregateCall sumZeroCall =
-        AggregateCall.create(sumAggFunc, oldCall.isDistinct(),
-            oldCall.isApproximate(), oldCall.ignoreNulls(),
-            oldCall.getArgList(), oldCall.filterArg, oldCall.distinctKeys,
-            oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(),
-            null, oldCall.name);
-    final AggregateCall countCall =
-        AggregateCall.create(SqlStdOperatorTable.COUNT,
-            oldCall.isDistinct(),
-            oldCall.isApproximate(),
-            oldCall.ignoreNulls(),
-            oldCall.getArgList(),
-            oldCall.filterArg,
-            oldCall.distinctKeys,
-            oldCall.collation,
-            oldAggRel.getGroupCount(),
-            oldAggRel,
-            null,
-            null);
-
-    // NOTE:  these references are with respect to the output
-    // of newAggRel
-    RexNode sumZeroRef =
-        rexBuilder.addAggCall(sumZeroCall,
-            nGroups,
-            newCalls,
-            aggCallMapping,
-            oldAggRel.getInput()::fieldIsNullable);
-    if (!oldCall.getType().isNullable()) {
-      // If SUM(x) is not nullable, the validator must have determined that
-      // nulls are impossible (because the group is never empty and x is never
-      // null). Therefore we translate to SUM0(x).
-      return sumZeroRef;
-    }
-    RexNode countRef =
-        rexBuilder.addAggCall(countCall,
-            nGroups,
-            newCalls,
-            aggCallMapping,
-            oldAggRel.getInput()::fieldIsNullable);
-    return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
-        rexBuilder.makeCall(SqlStdOperatorTable.EQUALS,
-            countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
-        rexBuilder.makeNullLiteral(sumZeroRef.getType()),
-        sumZeroRef);
-  }
-
-  private static RexNode reduceMinMax(
-      Aggregate oldAggRel,
-      AggregateCall oldCall,
-      List<AggregateCall> newCalls,
-      Map<AggregateCall, RexNode> aggCallMapping) {
-    final int nGroups = oldAggRel.getGroupCount();
-    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-    String functionName = oldCall.getAggregation().getName();
-
-    AggregationFunctionType functionType = AggregationFunctionType.getAggregationFunctionType(functionName);
-    SqlAggFunction aggFunc = new PinotSqlAggFunction(functionType.getName().toUpperCase(),
-        functionType.getSqlIdentifier(), functionType.getSqlKind(), functionType.getSqlReturnTypeInference(),
-        functionType.getSqlOperandTypeInference(), functionType.getSqlOperandTypeChecker(),
-        functionType.getSqlFunctionCategory());
-
-    final AggregateCall newCall =
-        AggregateCall.create(aggFunc,
-            oldCall.isDistinct(),
-            oldCall.isApproximate(),
-            oldCall.ignoreNulls(),
-            oldCall.getArgList(),
-            oldCall.filterArg,
-            oldCall.distinctKeys,
-            oldCall.collation,
-            oldAggRel.getGroupCount(),
-            oldAggRel.getInput(),
-            null,
-            null);
-
-    RexNode ref =
-        rexBuilder.addAggCall(newCall,
-            nGroups,
-            newCalls,
-            aggCallMapping,
-            oldAggRel.getInput()::fieldIsNullable);
-    return rexBuilder.makeCast(oldCall.getType(), ref);
-  }
-
-  /**
-   * Does a shallow clone of oldAggRel and updates aggCalls. Could be refactored
-   * into Aggregate and subclasses - but it's only needed for some
-   * subclasses.
-   *
-   * @param relBuilder Builder of relational expressions; at the top of its
-   *                   stack is its input
-   * @param oldAggregate LogicalAggregate to clone.
-   * @param newCalls  New list of AggregateCalls
-   */
-  protected void newAggregateRel(RelBuilder relBuilder,
-      Aggregate oldAggregate,
-      List<AggregateCall> newCalls) {
-    relBuilder.aggregate(
-        relBuilder.groupKey(oldAggregate.getGroupSet(), oldAggregate.getGroupSets()),
-        newCalls);
-  }
-
-  /**
-   * Adds a calculation with the expressions to compute the original aggregate
-   * calls from the decomposed ones.
-   *
-   * @param relBuilder Builder of relational expressions; at the top of its
-   *                   stack is its input
-   * @param rowType The output row type of the original aggregate.
-   * @param exprs The expressions to compute the original aggregate calls
-   */
-  protected void newCalcRel(RelBuilder relBuilder,
-      RelDataType rowType,
-      List<RexNode> exprs) {
-    relBuilder.project(exprs, rowType.getFieldNames());
-  }
-}
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
index 76ca322edf..ffde1200aa 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/rel/rules/PinotQueryRuleSets.java
@@ -86,7 +86,7 @@ public class PinotQueryRuleSets {
           CoreRules.AGGREGATE_UNION_AGGREGATE,
 
           // reduce aggregate functions like AVG, STDDEV_POP etc.
-          PinotAggregateReduceFunctionsRule.INSTANCE
+          CoreRules.AGGREGATE_REDUCE_FUNCTIONS
           );
 
   // Filter pushdown rules run using a RuleCollection since we want to push down a filter as much as possible in a
diff --git a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
index dd772c349c..8cd1289f5d 100644
--- a/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
+++ b/pinot-query-planner/src/main/java/org/apache/calcite/sql/fun/PinotOperatorTable.java
@@ -28,12 +28,10 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.calcite.sql.PinotSqlAggFunction;
 import org.apache.calcite.sql.SqlFunction;
-import org.apache.calcite.sql.SqlFunctionCategory;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.SqlOperator;
 import org.apache.calcite.sql.validate.SqlNameMatchers;
 import org.apache.calcite.util.Util;
-import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
 import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
 
@@ -60,17 +58,9 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
       Arrays.stream(AggregationFunctionType.values()).filter(func -> func.getSqlKind() != null)
           .flatMap(func -> Stream.of(func.getSqlKind())).collect(Collectors.toSet());
 
-  private static final Set<String> AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS =
-      Arrays.stream(AggregationFunctionType.values())
-          .filter(func -> func.getReduceFunctionName() != null)
-          .flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toUpperCase(),
-              func.getName().toLowerCase()))
-          .collect(Collectors.toSet());
-
-  // TODO: This is needed until all aggregation functions are registered with the operator table. SqlKind cannot be
-  //       null for all registered calcite functions
   private static final Set<String> CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS =
-      Arrays.stream(AggregationFunctionType.values()).filter(func -> func.getSqlKind() != null)
+      Arrays.stream(AggregationFunctionType.values())
+          .filter(func -> func.getSqlKind() != null) // TODO: remove this once all V1 AGG functions are registered
           .flatMap(func -> Stream.of(func.name(), func.getName(), func.getName().toUpperCase(),
               func.getName().toLowerCase()))
           .collect(Collectors.toSet());
@@ -88,24 +78,8 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
     return _instance;
   }
 
-  public static boolean isAggregationKindSupported(SqlKind sqlKind) {
-    return KINDS.contains(sqlKind);
-  }
-
-  public static boolean isAggregationReduceSupported(String functionName) {
-    if (AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS.contains(functionName)) {
-      return true;
-    }
-    String upperCaseFunctionName = AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
-    return AGGREGATION_REDUCE_SUPPORTED_FUNCTIONS.contains(upperCaseFunctionName);
-  }
-
   public static boolean isAggregationFunctionRegisteredWithOperatorTable(String functionName) {
-    if (CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(functionName)) {
-      return true;
-    }
-    String upperCaseFunctionName = AggregationFunctionType.getNormalizedAggregationFunctionName(functionName);
-    return CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(upperCaseFunctionName);
+    return CALCITE_OPERATOR_TABLE_REGISTERED_FUNCTIONS.contains(functionName);
   }
 
   /**
@@ -125,10 +99,7 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
         if (SqlFunction.class.isAssignableFrom(field.getType())) {
           SqlFunction op = (SqlFunction) field.get(this);
           if (op != null && notRegistered(op)) {
-            if (!isPinotAggregationFunction(op.getName())) {
-              // Register the standard Calcite functions and those defined in this class only
-              register(op);
-            }
+            register(op);
           }
         } else if (
             SqlOperator.class.isAssignableFrom(field.getType())) {
@@ -142,89 +113,32 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
       }
     }
 
-    // Walk through all the Pinot aggregation types and register those that are supported in multistage and which
-    // aren't standard Calcite functions such as SUM / MIN / MAX / COUNT etc.
+    // Walk through all the Pinot aggregation types and
+    //   1. register those supported in multistage, on top of the standard Calcite operator table.
+    //   2. register special handling that differs from the Calcite standard.
     for (AggregationFunctionType aggregationFunctionType : AggregationFunctionType.values()) {
-      if (aggregationFunctionType.isNativeCalciteAggregationFunctionType()
-          || aggregationFunctionType.getSqlKind() == null) {
-        // Skip registering functions which are standard Calcite functions and functions which are not yet supported
-        // in multistage
-        continue;
-      }
-
-      // Register the aggregation function with Calcite along with all alternative names
-      List<PinotSqlAggFunction> sqlAggFunctions = new ArrayList<>();
-      PinotSqlAggFunction aggFunction = generatePinotSqlAggFunction(aggregationFunctionType.getName(),
-          aggregationFunctionType, false);
-      sqlAggFunctions.add(aggFunction);
-      List<String> alternativeFunctionNames = aggregationFunctionType.getAlternativeNames();
-      if (alternativeFunctionNames == null || alternativeFunctionNames.size() == 0) {
-        // If no alternative function names are specified, generate one which converts camel case to have underscores
-        // as delimiters instead. E.g. boolAnd -> BOOL_AND
-        String alternativeFunctionName =
-            convertCamelCaseToUseUnderscores(aggregationFunctionType.getName());
-        PinotSqlAggFunction function = generatePinotSqlAggFunction(alternativeFunctionName, aggregationFunctionType,
-            false);
-        sqlAggFunctions.add(function);
-      } else {
+      if (aggregationFunctionType.getSqlKind() != null) {
+        // 1. Register the aggregation function with Calcite
+        registerAggregateFunction(aggregationFunctionType.getName(), aggregationFunctionType);
+        // 2. Register the aggregation function with Calcite on all alternative names
+        List<String> alternativeFunctionNames = aggregationFunctionType.getAlternativeNames();
         for (String alternativeFunctionName : alternativeFunctionNames) {
-          PinotSqlAggFunction function = generatePinotSqlAggFunction(alternativeFunctionName, aggregationFunctionType,
-              false);
-          sqlAggFunctions.add(function);
-        }
-      }
-      for (PinotSqlAggFunction sqlAggFunction : sqlAggFunctions) {
-        if (notRegistered(sqlAggFunction)) {
-          register(sqlAggFunction);
-        }
-      }
-
-      if (!StringUtils.isEmpty(aggregationFunctionType.getIntermediateFunctionName())
-          && !StringUtils.equals(aggregationFunctionType.getIntermediateFunctionName(),
-          aggregationFunctionType.getName())) {
-        // Register the intermediate function with Calcite if the name differs from the main function name
-        PinotSqlAggFunction intermediateAggFunction =
-            generatePinotSqlAggFunction(aggregationFunctionType.getIntermediateFunctionName(), aggregationFunctionType,
-            true);
-        if (notRegistered(intermediateAggFunction)) {
-          register(intermediateAggFunction);
-        }
-      }
-
-      if (!StringUtils.isEmpty(aggregationFunctionType.getReduceFunctionName())) {
-        // If a reduce function name is available, register this as a SqlFunction with Calcite
-        SqlFunction function = new SqlFunction(
-            aggregationFunctionType.getReduceFunctionName(),
-            SqlKind.OTHER_FUNCTION,
-            aggregationFunctionType.getSqlReduceReturnTypeInference(),
-            null,
-            aggregationFunctionType.getSqlReduceOperandTypeChecker(),
-            SqlFunctionCategory.USER_DEFINED_FUNCTION);
-        if (notRegistered(function)) {
-          register(function);
+          registerAggregateFunction(alternativeFunctionName, aggregationFunctionType);
         }
       }
     }
   }
 
-  private static String convertCamelCaseToUseUnderscores(String functionName) {
-    // Skip functions that have numbers for now and return their name as is
-    return functionName.matches(".*\\d.*")
-        ? functionName
-        : functionName.replaceAll("(.)(\\p{Upper}+|\\d+)", "$1_$2");
-  }
-
-  private static PinotSqlAggFunction generatePinotSqlAggFunction(String functionName,
-      AggregationFunctionType aggregationFunctionType, boolean isIntermediateStageFunction) {
-    return new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT),
-        aggregationFunctionType.getSqlIdentifier(),
-        aggregationFunctionType.getSqlKind(),
-        isIntermediateStageFunction
-            ? aggregationFunctionType.getSqlIntermediateReturnTypeInference()
-            : aggregationFunctionType.getSqlReturnTypeInference(),
-        aggregationFunctionType.getSqlOperandTypeInference(),
-        aggregationFunctionType.getSqlOperandTypeChecker(),
-        aggregationFunctionType.getSqlFunctionCategory());
+  private void registerAggregateFunction(String functionName, AggregationFunctionType functionType) {
+    // Register aggregation function behavior that differs from the Calcite standard
+    if (functionType.getOperandTypeChecker() != null && functionType.getReturnTypeInference() != null) {
+      PinotSqlAggFunction sqlAggFunction = new PinotSqlAggFunction(functionName.toUpperCase(Locale.ROOT), null,
+          functionType.getSqlKind(), functionType.getReturnTypeInference(), null,
+          functionType.getOperandTypeChecker(), functionType.getSqlFunctionCategory());
+      if (notRegistered(sqlAggFunction)) {
+        register(sqlAggFunction);
+      }
+    }
   }
 
   private boolean notRegistered(SqlFunction op) {
@@ -240,12 +154,4 @@ public class PinotOperatorTable extends SqlStdOperatorTable {
         SqlNameMatchers.withCaseSensitive(false));
     return operatorList.size() == 0;
   }
-
-  private boolean isPinotAggregationFunction(String name) {
-    AggregationFunctionType aggFunctionType = null;
-    if (isAggregationFunctionRegisteredWithOperatorTable(name)) {
-      aggFunctionType = AggregationFunctionType.getAggregationFunctionType(name);
-    }
-    return aggFunctionType != null && !aggFunctionType.isNativeCalciteAggregationFunctionType();
-  }
 }
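
The net effect of this file's change: Pinot aggregation functions are created
dynamically, and only when the function type supplies both an operand type checker
and a return type inference. A minimal sketch of the construction, assuming stock
Calcite type-inference strategies (the strategies shown for SKEWNESS are
illustrative, not necessarily the registered ones):

    // Argument order mirrors registerAggregateFunction above.
    PinotSqlAggFunction fn = new PinotSqlAggFunction(
        "SKEWNESS",                                  // upper-cased function name
        null,                                        // no SqlIdentifier
        SqlKind.OTHER_FUNCTION,                      // kind seen by the validator
        ReturnTypes.DOUBLE,                          // explicit return type inference
        null,                                        // operand type inference (unused)
        OperandTypes.NUMERIC,                        // operand type checker
        SqlFunctionCategory.USER_DEFINED_FUNCTION);
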
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
index 1d73743b09..05e837ee77 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryCompilationTest.java
@@ -437,7 +437,6 @@ public class QueryCompilationTest extends QueryEnvironmentTestBase {
         new Object[]{"EXPLAIN PLAN EXCLUDING ATTRIBUTES AS DOT FOR SELECT col1, COUNT(*) FROM a GROUP BY col1",
               "Execution Plan\n"
             + "digraph {\n"
-            + "\"LogicalAggregate\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
             + "\"PinotLogicalExchange\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
             + "\"LogicalAggregate\\n\" -> \"PinotLogicalExchange\\n\" [label=\"0\"]\n"
             + "\"LogicalTableScan\\n\" -> \"LogicalAggregate\\n\" [label=\"0\"]\n"
diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
index f4c2ceafa7..135e1cb208 100644
--- a/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
+++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/QueryEnvironmentTestBase.java
@@ -93,8 +93,9 @@ public class QueryEnvironmentTestBase {
         new Object[]{"SELECT SUM(a.col3), COUNT(*) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a'"},
         new Object[]{"SELECT AVG(a.col3), SUM(a.col3), COUNT(a.col3) FROM a"},
         new Object[]{"SELECT a.col1, AVG(a.col3), SUM(a.col3), COUNT(a.col3) FROM a GROUP BY a.col1"},
-        new Object[]{"SELECT BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a"},
-        new Object[]{"SELECT a.col3, BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a GROUP BY a.col3"},
+        // TODO: support BOOL_AND and BOOL_OR as MIN/MAX
+//        new Object[]{"SELECT BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a"},
+//        new Object[]{"SELECT a.col3, BOOL_AND(a.col5), BOOL_OR(a.col5) FROM a GROUP BY a.col3"},
         new Object[]{"SELECT KURTOSIS(a.col2), COUNT(DISTINCT a.col3), SKEWNESS(a.col3) FROM a"},
         new Object[]{"SELECT a.col1, KURTOSIS(a.col2), SKEWNESS(a.col3) FROM a GROUP BY a.col1"},
         new Object[]{"SELECT COUNT(a.col3), AVG(a.col3), SUM(a.col3), MIN(a.col3), MAX(a.col3) FROM a"},
diff --git a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
index 0cc18fe67f..102fb8681c 100644
--- a/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/AggregatePlans.json
@@ -6,14 +6,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[CAST(AVG_REDUCE(CAST($0):DECIMAL(1000, 0) NOT NULL, $1)):DECIMAL(1000, 0)])",
+          "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n          LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n    PinotLogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -22,14 +21,13 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col4) as avg, SUM(a.col4) as sum, MAX(a.col4) as max FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[CAST(AVG_REDUCE(CAST($0):DECIMAL(1000, 0) NOT NULL, $1)):DECIMAL(1000, 0)], sum=[$0], max=[$2])",
+          "\nLogicalProject(avg=[/(CASE(=($1, 0), null:DECIMAL(1000, 0), $0), $1)], sum=[CASE(=($1, 0), null:DECIMAL(1000, 0), $0)], max=[$2])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)], agg#2=[MAX($2)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)])",
-          "\n          LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n    PinotLogicalExchange(distribution=[hash])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], col4=[$3])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -38,28 +36,12 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[CAST(AVG_REDUCE($0, $1)):DOUBLE], count=[$1])",
+          "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, $0)):DOUBLE, $1)], count=[$1])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
-          "\n          LogicalProject(col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select aggregates with filters and select alias",
-        "sql": "EXPLAIN PLAN FOR SELECT KURTOSIS(a.col3) as kurtosis, DISTINCTCOUNT(a.col1) as dcount, MIN(a.col6), BOOL_AND(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'iron maiden'",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[KURTOSIS($0)], agg#1=[DISTINCTCOUNT($1)], agg#2=[MIN($2)], agg#3=[BOOLAND($3)])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)], agg#2=[MIN($2)], agg#3=[BOOLAND($3)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
-          "\n      LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($2)], agg#1=[DISTINCTCOUNT($0)], agg#2=[MIN($4)], agg#3=[BOOLAND($3)])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'iron maiden'))])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
           "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
@@ -69,7 +51,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3), COUNT(a.col1) FROM a",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+          "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:INTEGER, $0)], EXPR$1=[$1])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
@@ -77,26 +59,12 @@
           "\n"
         ]
       },
-      {
-        "description": "Select aggregates for COUNT variations",
-        "sql": "EXPLAIN PLAN FOR SELECT DISTINCTCOUNT(a.col3), COUNT(DISTINCT a.col3), COUNT(a.col3), COUNT(*), DISTINCTCOUNT(a.col5) FROM a",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$2], EXPR$4=[$3])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($0)], agg#1=[DISTINCTCOUNT(DISTINCT $1)], agg#2=[COUNT($2)], agg#3=[DISTINCTCOUNT($3)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($0)], agg#1=[DISTINCTCOUNT(DISTINCT $1)], agg#2=[COUNT($2)], agg#3=[DISTINCTCOUNT($3)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $2)], agg#2=[COUNT()], agg#3=[DISTINCTCOUNT($4)])",
-          "\n          LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
       {
         "description": "Select aggregates with filters",
         "sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3), COUNT(*) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a'",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+          "\nLogicalProject(EXPR$0=[CASE(=($1, 0), null:INTEGER, $0)], EXPR$1=[$1])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
@@ -106,93 +74,12 @@
           "\n"
         ]
       },
-      {
-        "description": "Select transform inside aggregate with filters",
-        "sql": "EXPLAIN PLAN FOR SELECT SUM(ADD(a.col3, a.col6)), AVG(MOD(a.col6, a.col3)), MIN(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND a.col2 = 'hooloovoo'",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[MIN($3)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[MIN($3)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[MIN($2)])",
-          "\n          LogicalProject($f0=[ADD($2, $5)], $f1=[MOD($5, $2)], $f2=[ABS($5)])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'hooloovoo'))])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select ORDER BY on aggregate with filters",
-        "sql": "EXPLAIN PLAN FOR SELECT SUM(a.col6), AVG(a.col3), KURTOSIS(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND a.col2 = 'zaphoid beeblebrox' ORDER BY SUM(a.col6), AVG(a.col3) DESC",
-        "output": [
-          "Execution Plan",
-          "\nLogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC], offset=[0])",
-          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0, 1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n    LogicalSort(sort0=[$0], sort1=[$1], dir0=[ASC], dir1=[DESC])",
-          "\n      LogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[KURTOSIS($3)])",
-          "\n          LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[FOURTHMOMENT($3)])",
-          "\n            PinotLogicalExchange(distribution=[hash])",
-          "\n              LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[FOURTHMOMENT($2)])",
-          "\n                LogicalProject(col6=[$5], col3=[$2], $f2=[ABS($5)])",
-          "\n                  LogicalFilter(condition=[AND(>=($2, 0), =($1, 'zaphoid beeblebrox'))])",
-          "\n                    LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select transform inside aggregate with transform inside filters",
-        "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(ABS(a.col6)) FROM a WHERE a.col3 >= 0 AND REVERSE(a.col2) = 'oovoolooh'",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[DISTINCTCOUNT($1)])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash])",
-          "\n      LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[DISTINCTCOUNT($1)])",
-          "\n        LogicalProject($f0=[ADD($2, $5)], $f1=[ABS($5)])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =(REVERSE($1), 'oovoolooh'))])",
-          "\n            LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select more aggregates with filters",
-        "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(a.col3), AVG(a.col3), BOOL_OR(a.col5), SUM(a.col3), AVG(a.col6), MAX(a.col6) FROM a WHERE a.col3 >= 0 AND a.col2 = 'rolling stones'",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(EXPR$0=[$0], EXPR$1=[CAST(AVG_REDUCE($1, $2)):DOUBLE], EXPR$2=[$3], EXPR$3=[$1], EXPR$4=[CAST(AVG_REDUCE($4, $2)):DOUBLE], EXPR$5=[$5])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[BOOLOR($3)], agg#4=[$SUM0($4)], agg#5=[MAX($5)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[$SUM0($1)], agg#2=[COUNT($2)], agg#3=[BOOLOR($3)], agg#4=[$SUM0($4)], agg#5=[MAX($5)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[BOOLOR($2)], agg#4=[$SUM0($3)], agg#5=[MAX($3)])",
-          "\n          LogicalProject(col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'rolling stones'))])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select kurtosis and skewness with limit",
-        "sql": "EXPLAIN PLAN FOR SELECT SKEWNESS(a.col3), KURTOSIS(a.col3) FROM a LIMIT 100",
-        "output": [
-          "Execution Plan",
-          "\nLogicalSort(fetch=[100])",
-          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[SKEWNESS($0)], agg#1=[KURTOSIS($1)])",
-          "\n      LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[FOURTHMOMENT($1)])",
-          "\n        PinotLogicalExchange(distribution=[hash])",
-          "\n          LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($2)], agg#1=[FOURTHMOMENT($2)])",
-          "\n            LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
       {
         "description": "Select aggregates with filters and select alias",
         "sql": "EXPLAIN PLAN FOR SELECT SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+          "\nLogicalProject(sum=[CASE(=($1, 0), null:INTEGER, $0)], count=[$1])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
@@ -207,43 +94,27 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ AVG(a.col3) as avg, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(avg=[CAST(AVG_REDUCE($0, $1)):DOUBLE], count=[$1])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n    LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
-          "\n      PinotLogicalExchange(distribution=[hash])",
-          "\n        LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
-          "\n          LogicalProject(col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Select aggregates with filters and select alias. The group by aggregate hint should be a no-op.",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'queen'",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
+          "\nLogicalProject(avg=[/(CAST(CASE(=($1, 0), null:INTEGER, $0)):DOUBLE, $1)], count=[$1])",
           "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
           "\n        LogicalProject(col2=[$1], col3=[$2])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'queen'))])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
           "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
       {
         "description": "Select aggregates with filters and select alias. The group by aggregate hint should be a no-op.",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ KURTOSIS(a.col3) as kurtosis, SKEWNESS(a.col6) as skewness, COUNT(DISTINCT a.col6), COUNT(DISTINCT a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'metallica'",
+        "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ SUM(a.col3) as sum, COUNT(*) as count FROM a WHERE a.col3 >= 0 AND a.col2 = 'pink floyd'",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{}], agg#0=[KURTOSIS($0)], agg#1=[SKEWNESS($1)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $3)])",
-          "\n  LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($0)], agg#1=[FOURTHMOMENT($1)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $3)])",
+          "\nLogicalProject(sum=[CASE(=($1, 0), null:INTEGER, $0)], count=[$1])",
+          "\n  LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($1)])",
           "\n    PinotLogicalExchange(distribution=[hash])",
-          "\n      LogicalAggregate(group=[{}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $1)])",
-          "\n        LogicalProject(col2=[$1], col3=[$2], col6=[$5])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'metallica'))])",
+          "\n      LogicalAggregate(group=[{}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+          "\n        LogicalProject(col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'pink floyd'))])",
           "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
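
A recurring pattern in the updated plans above is CASE(=($1, 0), null:..., $0):
$SUM0 returns 0 on empty input, while SQL's SUM and AVG must return NULL, so the
final-stage project restores the SQL semantics. A minimal sketch of that
reduction, assuming long-typed intermediate sum and count:

    final class SumAvgReduce {
      static Long sqlSum(long sum0, long count) {
        return count == 0 ? null : sum0;                  // CASE(=(count, 0), null, sum0)
      }
      static Double sqlAvg(long sum0, long count) {
        return count == 0 ? null : (double) sum0 / count; // /(CASE(...), count)
      }
    }
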
diff --git a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
index 3c586b369d..0b19599217 100644
--- a/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/GroupByPlans.json
@@ -7,55 +7,9 @@
         "output": [
           "Execution Plan",
           "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
-          "\n        LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by with select and aggregates with filters and select alias",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col2, KURTOSIS(a.col3) as kurtosis, DISTINCTCOUNT(a.col1) as dcount, MIN(a.col6), BOOL_AND(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'linkin park' GROUP BY a.col2",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[DISTINCTCOUNT($2)], agg#2=[MIN($3)], agg#3=[BOOLAND($4)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)], agg#2=[MIN($3)], agg#3=[BOOLAND($4)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{1}], agg#0=[FOURTHMOMENT($2)], agg#1=[DISTINCTCOUNT($0)], agg#2=[MIN($4)], agg#3=[BOOLAND($3)])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'linkin park'))])",
-          "\n            LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by with elect and aggregates for COUNT variations",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col1, a.col2, DISTINCTCOUNT(a.col3), COUNT(DISTINCT a.col3), COUNT(a.col3), COUNT(*), DISTINCTCOUNT(a.col5) FROM a GROUP BY a.col1, a.col2",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(col1=[$0], col2=[$1], EXPR$2=[$2], EXPR$3=[$3], EXPR$4=[$4], EXPR$5=[$4], EXPR$6=[$5])",
-          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $3)], agg#2=[COUNT($4)], agg#3=[DISTINCTCOUNT($5)])",
-          "\n    LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $3)], agg#2=[COUNT($4)], agg#3=[DISTINCTCOUNT($5)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[DISTINCTCOUNT($2)], agg#1=[DISTINCTCOUNT(DISTINCT $2)], agg#2=[COUNT()], agg#3=[DISTINCTCOUNT($4)])",
-          "\n          LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by with select and aggregates with filters",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col5, SKEWNESS(a.col3), AVG(a.col3), BOOL_OR(a.col5), SUM(a.col3), AVG(a.col6), MAX(a.col6) FROM a WHERE a.col3 >= 0 AND a.col2 = 'rolling stones' GROUP BY a.col5",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(col5=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($2, $3)], EXPR$3=[$4], EXPR$4=[$2], EXPR$5=[AVG_REDUCE($5, $3)], EXPR$6=[$6])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[$SUM0($2)], agg#2=[COUNT($3)], agg#3=[BOOLOR($4)], agg#4=[$SUM0($5)], agg#5=[MAX($6)])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($2)], agg#2=[COUNT($3)], agg#3=[BOOLOR($4)], agg#4=[$SUM0($5)], agg#5=[MAX($6)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalAggregate(group=[{2}], agg#0=[FOURTHMOMENT($1)], agg#1=[$SUM0($1)], agg#2=[COUNT()], agg#3=[BOOLOR($2)], agg#4=[$SUM0($3)], agg#5=[MAX($3)])",
-          "\n          LogicalProject(col2=[$1], col3=[$2], col5=[$4], col6=[$5])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'rolling stones'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+          "\n      LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -64,12 +18,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($1, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
           "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -79,12 +32,11 @@
         "output": [
           "Execution Plan",
           "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n            LogicalTableScan(table=[[a]])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n          LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -95,12 +47,11 @@
         "output": [
           "Execution Plan",
           "\nLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n            LogicalTableScan(table=[[a]])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
+          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n          LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -111,12 +62,11 @@
           "Execution Plan",
           "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
           "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n    LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n      LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
+          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -126,31 +76,13 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
           "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
-          "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by with having clause and different aggregation functions",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col1, KURTOSIS(a.col3), BOOL_OR(a.col5) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1 HAVING COUNT(*) > 10 AND MAX(a.col3) >= 0 AND MIN(a.col3) < 20 AND SKEWNESS(a.col3) <= 10 AND AVG(a.col3) = 5",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($3, 10), >=($4, 0), <($5, 20), <=($6, 10), =(AVG_REDUCE($7, $3), 5))])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[BOOLOR($2)], agg#2=[COUNT($3)], agg#3=[MAX($4)], agg#4=[MIN($5)], agg#5=[SKEWNESS($6)], agg#6=[$SUM0($7)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[BOOLOR($2)], agg#2=[COUNT($3)], agg#3=[MAX($4)], agg#4=[MIN($5)], agg#5=[FOURTHMOMENT($6)], agg#6=[$SUM0($7)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($2)], agg#1=[BOOLOR($3)], agg#2=[COUNT()], agg#3=[MAX($2)], agg#4=[MIN($2)], agg#5=[FOURTHMOMENT($2)], agg#6=[$SUM0($2)])",
-          "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2], col5=[$4])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -160,72 +92,25 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
           "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
-          "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n                LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by aggregation kurtosis and skewness with limit",
-        "sql": "EXPLAIN PLAN FOR SELECT a.col6, SKEWNESS(a.col3), KURTOSIS(a.col3) FROM a GROUP BY a.col6 LIMIT 100",
-        "output": [
-          "Execution Plan",
-          "\nLogicalSort(offset=[0], fetch=[100])",
-          "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[]], isSortOnSender=[false], isSortOnReceiver=[false])",
-          "\n    LogicalSort(fetch=[100])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[KURTOSIS($2)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{5}], agg#0=[FOURTHMOMENT($2)], agg#1=[FOURTHMOMENT($2)])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
           "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
-      {
-        "description": "Group by transform inside aggregation and transform in group by",
-        "sql": "EXPLAIN PLAN FOR SELECT REVERSE(a.col1), SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(CONCAT(a.col1, '-', 'a.col2')) FROM a GROUP BY REVERSE(a.col1)",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n        LogicalProject(EXPR$0=[REVERSE($0)], $f1=[ADD($2, $5)], $f2=[CONCAT($0, '-', 'a.col2')])",
-          "\n          LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "Group by transform inside aggregation and transform in group by with filter with transform",
-        "sql": "EXPLAIN PLAN FOR SELECT REVERSE(a.col1), SKEWNESS(ADD(a.col3, a.col6)), DISTINCTCOUNT(CONCAT(a.col1, '-', 'a.col2')) FROM a WHERE a.col3 >= 42 AND SUBSTR(a.col2, 0, 4) != 'prod' GROUP BY REVERSE(a.col1)",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[SKEWNESS($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[FOURTHMOMENT($1)], agg#1=[DISTINCTCOUNT($2)])",
-          "\n        LogicalProject(EXPR$0=[REVERSE($0)], $f1=[ADD($2, $5)], $f2=[CONCAT($0, '-', 'a.col2')])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 42), <>(SUBSTR($1, 0, 4), 'prod'))])",
-          "\n            LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
       {
         "description": "SQL hint based group by optimization with select and aggregate column",
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col1=[$0], col3=[$2])",
-          "\n        LogicalTableScan(table=[[a]])",
+          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalProject(col1=[$0], col3=[$2])",
+          "\n      LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -234,12 +119,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, AVG(a.col3) FROM a GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
-          "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalProject(col1=[$0], col3=[$2])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
+          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalProject(col1=[$0], col3=[$2])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -248,12 +132,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), AVG(a.col3), MAX(a.col3), MIN(a.col3) FROM a GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[AVG_REDUCE($1, $2)], EXPR$3=[$3], EXPR$4=[$4])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT()], agg#2=[MAX($1)], agg#3=[MIN($1)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalProject(col1=[$0], col3=[$2])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[/(CAST($1):DOUBLE NOT NULL, $2)], EXPR$3=[$3], EXPR$4=[$4])",
+          "\n  LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[COUNT()], EXPR$3=[MAX($1)], EXPR$4=[MIN($1)])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalProject(col1=[$0], col3=[$2])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -262,12 +145,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -276,12 +158,11 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ a.col1, SUM(a.col3), MAX(a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'a' GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($2)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[MAX($2)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)], EXPR$2=[MAX($2)])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -291,12 +172,11 @@
         "notes": "TODO: Needs follow up. Project should only keep a.col1 since the other columns are pushed to the filter, but it currently keeps them all",
         "output": [
           "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\nLogicalAggregate(group=[{0}], EXPR$1=[COUNT()])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n      LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n        LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -306,12 +186,11 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
-          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n    LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
-          "\n            LogicalTableScan(table=[[a]])",
+          "\n  LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+          "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+          "\n          LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -321,13 +200,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1], EXPR$2=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -337,13 +215,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col1=[$0], EXPR$1=[$1])",
-          "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(AVG_REDUCE($1, $4), 5))])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[MAX($2)], agg#2=[MIN($3)], agg#3=[COUNT($4)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[MAX($2)], agg#2=[MIN($2)], agg#3=[COUNT()])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>=($2, 0), <($3, 20), <=($1, 10), =(/(CAST($1):DOUBLE NOT NULL, $4), 5))])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($2)], agg#1=[MAX($2)], agg#2=[MIN($2)], agg#3=[COUNT()])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -353,27 +230,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(value1=[$0], count=[$1], SUM=[$2])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[MAX($3)], agg#3=[MIN($4)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
-      {
-        "description": "SQL hint based group by optimization with select and aggregates with filters and select alias",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ aggOptions(is_skip_leaf_stage_group_by='true') */ ts, KURTOSIS(a.col3) as kurtosis, SKEWNESS(a.col6) as skewness, COUNT(DISTINCT a.col6), COUNT(DISTINCT a.col3) FROM a WHERE a.col3 >= 0 AND a.col2 = 'metallica' GROUP BY ts",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0}], agg#0=[KURTOSIS($1)], agg#1=[SKEWNESS($2)], agg#2=[DISTINCTCOUNT(DISTINCT $3)], agg#3=[DISTINCTCOUNT(DISTINCT $4)])",
-          "\n  LogicalAggregate(group=[{3}], agg#0=[FOURTHMOMENT($1)], agg#1=[FOURTHMOMENT($2)], agg#2=[DISTINCTCOUNT(DISTINCT $2)], agg#3=[DISTINCTCOUNT(DISTINCT $1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[3]])",
-          "\n      LogicalProject(col2=[$1], col3=[$2], col6=[$5], ts=[$6])",
-          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($1, 'metallica'))])",
-          "\n          LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($3, 0), <($4, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{0}], count=[COUNT()], SUM=[$SUM0($2)], agg#2=[MAX($2)], agg#3=[MIN($2)])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       }
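
Note on the AVG_REDUCE removal above: the rewritten plans express AVG directly as the merged SUM divided by the merged COUNT, with the numerator cast to DOUBLE. A minimal standalone Java sketch of that final-stage reduction (names here are illustrative, not Pinot APIs):

    // Illustrative only: mirrors =(/(CAST($1):DOUBLE NOT NULL, $4), 5) from the plans above.
    public final class AvgReduceSketch {
      // Final-stage AVG: divide the merged SUM by the merged COUNT.
      static double reduceAvg(long sum, long count) {
        // CAST(sum):DOUBLE before dividing, matching the generated plan.
        return (double) sum / count;
      }

      public static void main(String[] args) {
        System.out.println(reduceAvg(50, 10)); // 5.0, as in the HAVING AVG(...) = 5 predicates
      }
    }
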
diff --git a/pinot-query-planner/src/test/resources/queries/JoinPlans.json b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
index 6d7b76c4e9..03f8c9dbf7 100644
--- a/pinot-query-planner/src/test/resources/queries/JoinPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/JoinPlans.json
@@ -116,20 +116,19 @@
         "sql": "EXPLAIN PLAN FOR SELECT a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2  WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n          LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalProject(col1=[$0])",
-          "\n                LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n                  LogicalTableScan(table=[[a]])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalProject(col2=[$1], col3=[$2])",
-          "\n                LogicalFilter(condition=[<($2, 0)])",
-          "\n                  LogicalTableScan(table=[[b]])",
+          "\n    PinotLogicalExchange(distribution=[hash[0]])",
+          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n        LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalProject(col1=[$0])",
+          "\n              LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n                LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalProject(col2=[$1], col3=[$2])",
+          "\n              LogicalFilter(condition=[<($2, 0)])",
+          "\n                LogicalTableScan(table=[[b]])",
           "\n"
         ]
       },
@@ -289,32 +288,26 @@
           "\n                          LogicalFilter(condition=[=($1, 'test')])",
           "\n                            LogicalTableScan(table=[[a]])",
           "\n                      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                        LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
-          "\n                          LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n                        LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n                          PinotLogicalExchange(distribution=[hash[0]])",
           "\n                            LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n                              PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                                LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n                                  LogicalProject(col3=[$2], $f1=[true])",
-          "\n                                    LogicalFilter(condition=[=($0, 'foo')])",
-          "\n                                      LogicalTableScan(table=[[b]])",
+          "\n                              LogicalProject(col3=[$2], $f1=[true])",
+          "\n                                LogicalFilter(condition=[=($0, 'foo')])",
+          "\n                                  LogicalTableScan(table=[[b]])",
           "\n              PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
-          "\n                  LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n                LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
           "\n                    LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n                      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                        LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n                          LogicalProject(col3=[$2], $f1=[true])",
-          "\n                            LogicalFilter(condition=[=($0, 'bar')])",
-          "\n                              LogicalTableScan(table=[[b]])",
+          "\n                      LogicalProject(col3=[$2], $f1=[true])",
+          "\n                        LogicalFilter(condition=[=($0, 'bar')])",
+          "\n                          LogicalTableScan(table=[[b]])",
           "\n      PinotLogicalExchange(distribution=[hash[0]])",
-          "\n        LogicalProject(col3=[$0], $f1=[CAST($1):BOOLEAN NOT NULL])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n        LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
           "\n            LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n              PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[MIN($1)])",
-          "\n                  LogicalProject(col3=[$2], $f1=[true])",
-          "\n                    LogicalFilter(condition=[=($0, 'foobar')])",
-          "\n                      LogicalTableScan(table=[[b]])",
+          "\n              LogicalProject(col3=[$2], $f1=[true])",
+          "\n                LogicalFilter(condition=[=($0, 'foobar')])",
+          "\n                  LogicalTableScan(table=[[b]])",
           "\n"
         ]
       },
@@ -328,12 +321,11 @@
           "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n      LogicalTableScan(table=[[a]])",
           "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n      LogicalProject(col1=[$0], col2=[$1], col10=[CAST($0):VARCHAR], col20=[CAST($1):VARCHAR], $f2=[CAST($2):DOUBLE], EXPR$3=[*(0.5:DECIMAL(2, 1), $2)])",
+          "\n      LogicalProject(col1=[$0], col2=[$1], col10=[CAST($0):VARCHAR], col20=[CAST($1):VARCHAR], $f2=[CAST($2):INTEGER], EXPR$3=[*(0.5:DECIMAL(2, 1), $2)])",
           "\n        LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n          LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n            PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n              LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n                LogicalTableScan(table=[[b]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n            LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
+          "\n              LogicalTableScan(table=[[b]])",
           "\n"
         ]
       },
@@ -349,10 +341,9 @@
           "\n        LogicalTableScan(table=[[a]])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[b]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[b]])",
           "\n"
         ]
       }
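
Note on the aggregate dedup above: each plan now keeps exactly one aggregate below the exchange (LEAF, computing partials) and one above it (FINAL, merging them); the former middle aggregate is gone. A simplified standalone sketch of such two-phase SUM aggregation, assuming in-memory partitions (illustrative only, not the Pinot operator code):

    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    // Illustrative two-phase aggregation: LEAF computes partial sums per server,
    // the exchange hashes on the group key, and FINAL merges the partials.
    public final class TwoPhaseSumSketch {
      static Map<String, Long> leaf(List<Map.Entry<String, Long>> rows) {
        return rows.stream().collect(
            Collectors.groupingBy(Map.Entry::getKey, Collectors.summingLong(Map.Entry::getValue)));
      }

      static Map<String, Long> finalMerge(List<Map<String, Long>> partials) {
        return partials.stream().flatMap(m -> m.entrySet().stream()).collect(
            Collectors.groupingBy(Map.Entry::getKey, Collectors.summingLong(Map.Entry::getValue)));
      }

      public static void main(String[] args) {
        Map<String, Long> p1 = leaf(List.of(Map.entry("a", 1L), Map.entry("b", 2L)));
        Map<String, Long> p2 = leaf(List.of(Map.entry("a", 3L)));
        System.out.println(finalMerge(List.of(p1, p2))); // {a=4, b=2}
      }
    }
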
diff --git a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
index f205f0bb89..a6b391ca1c 100644
--- a/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/OrderByPlans.json
@@ -60,10 +60,9 @@
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -75,11 +74,10 @@
           "\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalProject(col1=[$0], col3=[$2])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n      LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalProject(col1=[$0], col3=[$2])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -92,10 +90,9 @@
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -107,11 +104,10 @@
           "\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalProject(col1=[$0], col3=[$2])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n      LogicalAggregate(group=[{0}], sum=[$SUM0($1)])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalProject(col1=[$0], col3=[$2])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       }
diff --git a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
index 08e4b2579a..f5ba9d3182 100644
--- a/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PinotHintablePlans.json
@@ -20,7 +20,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */a.col1, AVG(b.col3) FROM a JOIN b ON a.col1 = b.col2  WHERE a.col3 >= 0 AND a.col2 = 'a' AND b.col3 < 0 GROUP BY a.col1",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(col1=[$0], EXPR$1=[AVG_REDUCE($1, $2)])",
+          "\nLogicalProject(col1=[$0], EXPR$1=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
           "\n    LogicalJoin(condition=[=($0, $1)], joinType=[inner])",
           "\n      PinotLogicalExchange(distribution=[single])",
@@ -104,17 +104,16 @@
         "output": [
           "Execution Plan",
           "\nLogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n  LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalJoin(condition=[=($0, $3)], joinType=[semi])",
-          "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n              LogicalTableScan(table=[[a]])",
-          "\n            PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])",
-          "\n              LogicalProject(col2=[$1], col3=[$2])",
-          "\n                LogicalFilter(condition=[>($2, 0)])",
-          "\n                  LogicalTableScan(table=[[b]])",
+          "\n  PinotLogicalExchange(distribution=[hash[0]])",
+          "\n    LogicalAggregate(group=[{1}], agg#0=[$SUM0($2)])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalJoin(condition=[=($0, $3)], joinType=[semi])",
+          "\n          LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n            LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[broadcast], relExchangeType=[PIPELINE_BREAKER])",
+          "\n            LogicalProject(col2=[$1], col3=[$2])",
+          "\n              LogicalFilter(condition=[>($2, 0)])",
+          "\n                LogicalTableScan(table=[[b]])",
           "\n"
         ]
       },
@@ -124,12 +123,11 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col2=[$1], col1=[$0], EXPR$2=[$2])",
-          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n    LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)])",
-          "\n      PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n        LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
-          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
-          "\n            LogicalTableScan(table=[[a]])",
+          "\n  LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
+          "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n      LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
+          "\n        LogicalFilter(condition=[AND(>=($2, 0), =($0, 'a'))])",
+          "\n          LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -139,13 +137,12 @@
         "output": [
           "Execution Plan",
           "\nLogicalProject(col2=[$0], EXPR$1=[$1], EXPR$2=[$2], EXPR$3=[$3])",
-          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(AVG_REDUCE($2, $1), 5))])",
-          "\n    LogicalAggregate(group=[{0}], agg#0=[COUNT($1)], agg#1=[$SUM0($2)], agg#2=[$SUM0($3)], agg#3=[MAX($4)], agg#4=[MIN($5)])",
-          "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT()], agg#1=[$SUM0($1)], agg#2=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
-          "\n        PinotLogicalExchange(distribution=[hash[0]])",
-          "\n          LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
-          "\n            LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n  LogicalFilter(condition=[AND(>($1, 10), >=($4, 0), <($5, 20), <=($2, 10), =(/(CAST($2):DOUBLE NOT NULL, $1), 5))])",
+          "\n    LogicalAggregate(group=[{0}], EXPR$1=[COUNT()], EXPR$2=[$SUM0($1)], EXPR$3=[$SUM0($2)], agg#3=[MAX($1)], agg#4=[MIN($1)])",
+          "\n      PinotLogicalExchange(distribution=[hash[0]])",
+          "\n        LogicalProject(col2=[$1], col3=[$2], $f2=[CAST($0):DECIMAL(1000, 500) NOT NULL])",
+          "\n          LogicalFilter(condition=[AND(>=($2, 0), =($1, 'a'))])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -173,23 +170,6 @@
           "\n      LogicalTableScan(table=[[a]])",
           "\n"
         ]
-      },
-      {
-        "description": "Join and agg options",
-        "sql": "EXPLAIN PLAN FOR SELECT /*+ joinOptions(is_colocated_by_join_keys='true'), aggOptions(is_partitioned_by_group_by_keys='true') */ a.col3, a.col1, SUM(b.col3) FROM a JOIN b ON a.col3 = b.col3 GROUP BY a.col3, a.col1",
-        "output": [
-          "Execution Plan",
-          "\nLogicalProject(col3=[$1], col1=[$0], EXPR$2=[$2])",
-          "\n  LogicalAggregate(group=[{0, 1}], EXPR$2=[$SUM0($2)])",
-          "\n    LogicalJoin(condition=[=($1, $2)], joinType=[inner])",
-          "\n      PinotLogicalExchange(distribution=[single])",
-          "\n        LogicalProject(col1=[$0], col3=[$2])",
-          "\n          LogicalTableScan(table=[[a]])",
-          "\n      PinotLogicalExchange(distribution=[single])",
-          "\n        LogicalProject(col3=[$2])",
-          "\n          LogicalTableScan(table=[[b]])",
-          "\n"
-        ]
       }
     ]
   }
diff --git a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
index ed3f067cac..262a73b52b 100644
--- a/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/WindowFunctionPlans.json
@@ -264,10 +264,9 @@
           "\n  LogicalWindow(window#0=[window(aggs [MIN($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -279,12 +278,11 @@
           "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)])",
           "\n  LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash])",
-          "\n      LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n      LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -299,12 +297,11 @@
           "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], col3=[$0])",
           "\n        LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
           "\n          PinotLogicalExchange(distribution=[hash])",
-          "\n            LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n            LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n              LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                    LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                      LogicalTableScan(table=[[a]])",
+          "\n                PinotLogicalExchange(distribution=[hash[0]])",
+          "\n                  LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n                    LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -319,12 +316,11 @@
           "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[$2], col3=[$0])",
           "\n        LogicalWindow(window#0=[window( rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
           "\n          PinotLogicalExchange(distribution=[hash])",
-          "\n            LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n            LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n              LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                    LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                      LogicalTableScan(table=[[a]])",
+          "\n                PinotLogicalExchange(distribution=[hash[0]])",
+          "\n                  LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n                    LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -525,10 +521,9 @@
           "\n  LogicalWindow(window#0=[window(aggs [MIN($0), MAX($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -537,15 +532,14 @@
         "sql": "EXPLAIN PLAN FOR SELECT AVG(a.col3), AVG(a.col3) OVER(), SUM(a.col3) OVER() FROM a GROUP BY a.col3",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4])",
-          "\n  LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0), SUM($0)])])",
+          "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$2])",
+          "\n  LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash])",
-          "\n      LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n      LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -557,15 +551,14 @@
           "\nLogicalSort(sort0=[$3], dir0=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], dir0=[ASC])",
-          "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4], col3=[$0])",
-          "\n        LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0), SUM($0)])])",
+          "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$2], col3=[$0])",
+          "\n        LogicalWindow(window#0=[window(aggs [SUM($0), COUNT($0)])])",
           "\n          PinotLogicalExchange(distribution=[hash])",
-          "\n            LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n            LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n              LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                    LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                      LogicalTableScan(table=[[a]])",
+          "\n                PinotLogicalExchange(distribution=[hash[0]])",
+          "\n                  LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n                    LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -861,10 +854,9 @@
           "\n  LogicalWindow(window#0=[window(partition {0} aggs [MIN($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -876,12 +868,11 @@
           "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)])",
           "\n  LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n      LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -896,12 +887,11 @@
           "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], col3=[$0])",
           "\n        LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n            LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n              LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                    LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                      LogicalTableScan(table=[[a]])",
+          "\n                PinotLogicalExchange(distribution=[hash[0]])",
+          "\n                  LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n                    LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -1044,8 +1034,8 @@
           "\nLogicalSort(sort0=[$0], dir0=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[ASC])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -1093,8 +1083,8 @@
           "\nLogicalSort(sort0=[$3], dir0=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], dir0=[ASC])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -1194,10 +1184,9 @@
           "\n  LogicalWindow(window#0=[window(partition {0} aggs [MIN($0), SUM($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -1209,12 +1198,11 @@
           "\nLogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4])",
           "\n  LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])",
           "\n    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n      LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n      LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n        LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n          LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n            PinotLogicalExchange(distribution=[hash[0]])",
-          "\n              LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                LogicalTableScan(table=[[a]])",
+          "\n          PinotLogicalExchange(distribution=[hash[0]])",
+          "\n            LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n              LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -1229,12 +1217,11 @@
           "\n      LogicalProject(EXPR$0=[$1], EXPR$1=[/(CAST($2):DOUBLE NOT NULL, $3)], EXPR$2=[$4], col3=[$0])",
           "\n        LogicalWindow(window#0=[window(partition {0} aggs [SUM($0), COUNT($0), MAX($0)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalProject(col3=[$0], EXPR$0=[AVG_REDUCE($1, $2)])",
+          "\n            LogicalProject(col3=[$0], EXPR$0=[/(CAST($1):DOUBLE NOT NULL, $2)])",
           "\n              LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($2)])",
-          "\n                  PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                    LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
-          "\n                      LogicalTableScan(table=[[a]])",
+          "\n                PinotLogicalExchange(distribution=[hash[0]])",
+          "\n                  LogicalAggregate(group=[{2}], agg#0=[$SUM0($2)], agg#1=[COUNT()])",
+          "\n                    LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -1839,8 +1826,8 @@
           "\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[DESC])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
-          "\n        LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+          "\n        LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalSortExchange(distribution=[hash], collation=[[1, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -1887,8 +1874,8 @@
           "\nLogicalSort(sort0=[$0], dir0=[DESC], offset=[0], fetch=[10])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$0], dir0=[DESC], fetch=[10])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)])",
-          "\n        LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)])",
+          "\n        LogicalWindow(window#0=[window(order by [1, 0 DESC] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalSortExchange(distribution=[hash], collation=[[1, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -2519,8 +2506,8 @@
           "\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -2567,8 +2554,8 @@
           "\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], offset=[0], fetch=[10])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[ASC], fetch=[10])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [1, 0] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -3186,8 +3173,8 @@
           "\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], offset=[0])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalSortExchange(distribution=[hash[0, 1]], collation=[[2, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -3234,8 +3221,8 @@
           "\nLogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], offset=[0], fetch=[10])",
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[3, 0 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalSort(sort0=[$3], sort1=[$0], dir0=[ASC], dir1=[DESC], fetch=[10])",
-          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($4):DOUBLE NOT NULL, $5)], col2=[$1])",
-          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), SUM($2), COUNT($2)])])",
+          "\n      LogicalProject(col1=[$0], EXPR$1=[$3], EXPR$2=[/(CAST($3):DOUBLE NOT NULL, $4)], col2=[$1])",
+          "\n        LogicalWindow(window#0=[window(partition {0, 1} order by [2, 0] aggs [SUM($2), COUNT($2)])])",
           "\n          PinotLogicalSortExchange(distribution=[hash[0, 1]], collation=[[2, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n            LogicalProject(col1=[$0], col2=[$1], col3=[$2])",
           "\n              LogicalTableScan(table=[[a]])",
@@ -3395,10 +3382,9 @@
           "\n  LogicalWindow(window#0=[window(order by [2 DESC, 0] aggs [SUM($1), COUNT($1)])])",
           "\n    PinotLogicalSortExchange(distribution=[hash], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n      LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n            LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n          LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -3411,10 +3397,9 @@
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalProject(col1=[$0], EXPR$1=[$2])",
           "\n      LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n            LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n          LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -3427,10 +3412,9 @@
           "\n  PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n    LogicalProject(col1=[$0], EXPR$1=[$2])",
           "\n      LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n            LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n          LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -3443,10 +3427,9 @@
           "\n  LogicalWindow(window#0=[window(partition {0} order by [2 DESC, 0] aggs [MAX($1)])])",
           "\n    PinotLogicalSortExchange(distribution=[hash[0]], collation=[[2 DESC, 0]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n      LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n        LogicalAggregate(group=[{0, 1}], agg#0=[COUNT($2)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n            LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0, 1]])",
+          "\n          LogicalAggregate(group=[{0, 2}], agg#0=[COUNT()])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -3478,24 +3461,6 @@
           "\n"
         ]
       },
-      {
-        "description": "Window function CTE: row_number WITH statement having OVER with PARTITION BY ORDER BY and aggregations and transforms",
-        "sql": "EXPLAIN PLAN FOR WITH windowfunc AS (SELECT REVERSE(a.col1) as rev, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3) as rownum, a.col6 from a) SELECT rev, a.rownum, MAX(a.col6) FROM windowfunc AS a where a.rownum < 5 GROUP BY rev, a.rownum",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
-          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n      LogicalAggregate(group=[{0, 1}], agg#0=[MAX($2)])",
-          "\n        LogicalProject(rev=[$3], rownum=[$4], col6=[$2])",
-          "\n          LogicalFilter(condition=[<($4, 5)])",
-          "\n            LogicalWindow(window#0=[window(partition {0} order by [1] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
-          "\n              PinotLogicalSortExchange(distribution=[hash[0]], collation=[[1]], isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n                LogicalProject(col2=[$1], col3=[$2], col6=[$5], $3=[REVERSE($0)])",
-          "\n                  LogicalTableScan(table=[[a]])",
-          "\n"
-        ]
-      },
       {
         "description": "Window function subquery: row_number having OVER with PARTITION BY ORDER BY",
         "sql": "EXPLAIN PLAN FOR SELECT row_number, col2, col3 FROM (SELECT ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3 DESC) as row_number, a.col2, a.col3 FROM a) WHERE row_number <= 10",
@@ -3519,10 +3484,9 @@
           "\n  LogicalWindow(window#0=[window(order by [1 DESC] aggs [RANK()])])",
           "\n    PinotLogicalSortExchange(distribution=[hash], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
           "\n      LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
-          "\n        LogicalAggregate(group=[{0}], agg#0=[COUNT($1)])",
-          "\n          PinotLogicalExchange(distribution=[hash[0]])",
-          "\n            LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
-          "\n              LogicalTableScan(table=[[a]])",
+          "\n        PinotLogicalExchange(distribution=[hash[0]])",
+          "\n          LogicalAggregate(group=[{0}], agg#0=[COUNT()])",
+          "\n            LogicalTableScan(table=[[a]])",
           "\n"
         ]
       },
@@ -3569,32 +3533,6 @@
           "\n                  LogicalTableScan(table=[[b]])",
           "\n"
         ]
-      },
-      {
-        "description": "Window function subquery with join using row_number and aggregation",
-        "sql": "EXPLAIN PLAN FOR SELECT row_number, col2, SUM(col3), KURTOSIS(MULT(col3, col6)) FROM (SELECT a.col2 as col2, a.col3 as col3, a.col6 as col6, ROW_NUMBER() OVER(PARTITION BY a.col2 ORDER BY a.col3 DESC) as row_number FROM a INNER JOIN b ON a.col1 = b.col2 WHERE a.col3 > 100 AND b.col1 IN ('douglas adams', 'brandon sanderson')) where row_number <= 3 GROUP BY row_number, col2",
-        "output": [
-          "Execution Plan",
-          "\nLogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[KURTOSIS($3)])",
-          "\n  LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[FOURTHMOMENT($3)])",
-          "\n    PinotLogicalExchange(distribution=[hash[0, 1]])",
-          "\n      LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[FOURTHMOMENT($3)])",
-          "\n        LogicalProject(row_number=[$3], col2=[$0], col3=[$1], $f3=[MULT($1, $2)])",
-          "\n          LogicalFilter(condition=[<=($3, 3)])",
-          "\n            LogicalWindow(window#0=[window(partition {0} order by [1 DESC] rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])])",
-          "\n              PinotLogicalSortExchange(distribution=[hash[0]], collation=[[1 DESC]], isSortOnSender=[false], isSortOnReceiver=[true])",
-          "\n                LogicalProject(col2=[$1], col3=[$2], col6=[$3])",
-          "\n                  LogicalJoin(condition=[=($0, $4)], joinType=[inner])",
-          "\n                    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                      LogicalProject(col1=[$0], col2=[$1], col3=[$2], col6=[$5])",
-          "\n                        LogicalFilter(condition=[>($2, 100)])",
-          "\n                          LogicalTableScan(table=[[a]])",
-          "\n                    PinotLogicalExchange(distribution=[hash[0]])",
-          "\n                      LogicalProject(col2=[$1])",
-          "\n                        LogicalFilter(condition=[OR(=($0, 'brandon sanderson':VARCHAR(17)), =($0, 'douglas adams':VARCHAR(17)))])",
-          "\n                          LogicalTableScan(table=[[b]])",
-          "\n"
-        ]
       }
     ]
   },
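
Note on the window dedup above: a window list like [SUM($2), SUM($2), COUNT($2)] collapses to [SUM($2), COUNT($2)], and the projection reuses the single SUM column both as the SUM output and as the AVG numerator (hence the ordinal shift from $4/$5 to $3/$4). A standalone sketch of that reuse (illustrative only):

    // Illustrative only: compute SUM and COUNT once, derive both outputs from them.
    public final class WindowDedupSketch {
      public static void main(String[] args) {
        long[] col3 = {2, 4, 6};
        long sum = 0;
        long count = 0;
        for (long v : col3) {
          sum += v;
          count++;
        }
        long sumOut = sum;                    // SUM(col3) OVER(...)
        double avgOut = (double) sum / count; // AVG(col3) OVER(...), reusing the same SUM partial
        System.out.println(sumOut + " " + avgOut); // 12 4.0
      }
    }
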
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 7dd9a546c2..1cc0bc5d19 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -19,17 +19,21 @@
 package org.apache.pinot.query.runtime.operator;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.request.context.FunctionContext;
 import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.common.IntermediateStageBlockValSet;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunctionFactory;
 import org.apache.pinot.query.planner.logical.RexExpression;
@@ -71,7 +75,7 @@ public class AggregateOperator extends MultiStageOperator {
 
   // Map that maintains the mapping between columnName and the column ordinal index. It is used to fetch the required
   // column value from row-based container and fetch the input datatype for the column.
-  private final HashMap<String, Integer> _colNameToIndexMap;
+  private final Map<String, Integer> _colNameToIndexMap;
 
   private TransferableBlock _upstreamErrorBlock;
   private boolean _readyToConstruct;
@@ -112,10 +116,12 @@ public class AggregateOperator extends MultiStageOperator {
     // Initialize the appropriate executor.
     if (!groupSet.isEmpty()) {
       _isGroupByAggregation = true;
-      _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap);
+      _groupByExecutor = new MultistageGroupByExecutor(groupByExpr, aggFunctions, aggType, _colNameToIndexMap,
+          _resultSchema);
     } else {
       _isGroupByAggregation = false;
-      _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap);
+      _aggregationExecutor = new MultistageAggregationExecutor(aggFunctions, aggType, _colNameToIndexMap,
+          _resultSchema);
     }
   }
 
@@ -267,4 +273,39 @@ public class AggregateOperator extends MultiStageOperator {
 
     return exprContext;
   }
+
+  static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
+      TransferableBlock block, DataSchema inputDataSchema, Map<String, Integer> colNameToIndexMap) {
+    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
+    int numExpressions = expressions.size();
+    if (numExpressions == 0) {
+      return Collections.emptyMap();
+    }
+
+    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
+    ExpressionContext expression = expressions.get(0);
+    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
+    int index = colNameToIndexMap.get(expression.getIdentifier());
+
+    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
+    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
+    // TODO: If the previous block is not received from a mailbox, this path is inefficient: getDataBlock() will
+    //  first convert the unserialized format into the serialized BaseDataBlock format, and then convert it back to
+    //  column-value primitive types.
+    return Collections.singletonMap(expression,
+        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
+  }
+
+  static Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row,
+      Map<String, Integer> colNameToIndexMap) {
+    List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
+    Preconditions.checkState(expressions.size() == 1, "Only a single input expression is supported, got: %s", expressions.size());
+    ExpressionContext expr = expressions.get(0);
+    ExpressionContext.Type exprType = expr.getType();
+    if (exprType == ExpressionContext.Type.IDENTIFIER) {
+      return row[colNameToIndexMap.get(expr.getIdentifier())];
+    }
+    Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
+    return expr.getLiteral().getValue();
+  }
 }
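
For context, the new extractValueFromRow helper resolves an aggregation input that is either a column identifier (via the name-to-ordinal map) or a literal. A simplified standalone equivalent, with hypothetical names and no Pinot types:

    import java.util.Map;

    // Simplified stand-in for AggregateOperator.extractValueFromRow: the single
    // aggregation input is either a column identifier (looked up through the
    // name->ordinal map) or a literal passed through unchanged.
    public final class ExtractValueSketch {
      static Object extract(String identifierOrNull, Object literal,
          Object[] row, Map<String, Integer> colNameToIndexMap) {
        if (identifierOrNull != null) {
          return row[colNameToIndexMap.get(identifierOrNull)];
        }
        return literal;
      }

      public static void main(String[] args) {
        Object[] row = {"a", 42};
        System.out.println(extract("col2", null, row, Map.of("col1", 0, "col2", 1))); // 42
        System.out.println(extract(null, 7, row, Map.of())); // literal passthrough: 7
      }
    }
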
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
index 2851aa528a..d2ec50129d 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/FilterOperator.java
@@ -28,7 +28,7 @@ import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -100,7 +100,7 @@ public class FilterOperator extends MultiStageOperator {
     List<Object[]> resultRows = new ArrayList<>();
     List<Object[]> container = block.getContainer();
     for (Object[] row : container) {
-      if ((Boolean) FunctionInvokeUtils.convert(_filterOperand.apply(row), DataSchema.ColumnDataType.BOOLEAN)) {
+      if ((Boolean) TypeUtils.convert(_filterOperand.apply(row), DataSchema.ColumnDataType.BOOLEAN)) {
         resultRows.add(row);
       }
     }
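
The TypeUtils.convert call above canonicalizes the transform operand's result to BOOLEAN before it is used as a filter predicate. A minimal sketch of that kind of conversion (assumed behavior for illustration, not the actual TypeUtils implementation):

    // Assumed behavior only: numeric 0/1 results from a transform operand are
    // canonicalized to Boolean before being used as a filter predicate.
    public final class BooleanConvertSketch {
      static Boolean toBoolean(Object value) {
        if (value instanceof Boolean) {
          return (Boolean) value;
        }
        if (value instanceof Number) {
          return ((Number) value).intValue() == 1;
        }
        throw new IllegalStateException("Cannot convert to BOOLEAN: " + value);
      }

      public static void main(String[] args) {
        System.out.println(toBoolean(1));     // true
        System.out.println(toBoolean(0));     // false
        System.out.println(toBoolean(true));  // true
      }
    }
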
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
index 61798f3dc7..8d5e5b1a06 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/HashJoinOperator.java
@@ -38,7 +38,7 @@ import org.apache.pinot.query.planner.plannode.JoinNode;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -245,7 +245,7 @@ public class HashJoinOperator extends MultiStageOperator {
             // TODO: Optimize this to avoid unnecessary object copy.
             Object[] resultRow = joinRow(leftRow, rightRow);
             if (_joinClauseEvaluators.isEmpty() || _joinClauseEvaluators.stream().allMatch(
-                evaluator -> (Boolean) FunctionInvokeUtils.convert(evaluator.apply(resultRow),
+                evaluator -> (Boolean) TypeUtils.convert(evaluator.apply(resultRow),
                     DataSchema.ColumnDataType.BOOLEAN))) {
               rows.add(resultRow);
               hasMatchForLeftRow = true;
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
index 64e9117a92..1f7069ea45 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/LeafStageTransferableBlockOperator.java
@@ -43,6 +43,7 @@ import org.apache.pinot.core.operator.blocks.results.SelectionResultsBlock;
 import org.apache.pinot.core.query.request.ServerQueryRequest;
 import org.apache.pinot.core.query.selection.SelectionOperatorUtils;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -254,12 +255,12 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
     List<Object[]> extractedRows = new ArrayList<>(resultRows.size());
     if (resultRows instanceof List) {
       for (Object[] row : resultRows) {
-        extractedRows.add(canonicalizeRow(row, desiredDataSchema, columnIndices));
+        extractedRows.add(TypeUtils.canonicalizeRow(row, desiredDataSchema, columnIndices));
       }
     } else if (resultRows instanceof PriorityQueue) {
       PriorityQueue<Object[]> priorityQueue = (PriorityQueue<Object[]>) resultRows;
       while (!priorityQueue.isEmpty()) {
-        extractedRows.add(canonicalizeRow(priorityQueue.poll(), desiredDataSchema, columnIndices));
+        extractedRows.add(TypeUtils.canonicalizeRow(priorityQueue.poll(), desiredDataSchema, columnIndices));
       }
     }
     return new TransferableBlock(extractedRows, desiredDataSchema, DataBlock.Type.ROW);
@@ -277,12 +278,12 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
     List<Object[]> extractedRows = new ArrayList<>(resultRows.size());
     if (resultRows instanceof List) {
       for (Object[] orgRow : resultRows) {
-        extractedRows.add(canonicalizeRow(orgRow, desiredDataSchema));
+        extractedRows.add(TypeUtils.canonicalizeRow(orgRow, desiredDataSchema));
       }
     } else if (resultRows instanceof PriorityQueue) {
       PriorityQueue<Object[]> priorityQueue = (PriorityQueue<Object[]>) resultRows;
       while (!priorityQueue.isEmpty()) {
-        extractedRows.add(canonicalizeRow(priorityQueue.poll(), desiredDataSchema));
+        extractedRows.add(TypeUtils.canonicalizeRow(priorityQueue.poll(), desiredDataSchema));
       }
     } else {
       throw new UnsupportedOperationException("Unsupported collection type: " + resultRows.getClass());
@@ -299,41 +300,6 @@ public class LeafStageTransferableBlockOperator extends MultiStageOperator {
     return true;
   }
 
-  /**
-   * This util is used to canonicalize row generated from V1 engine, which is stored using
-   * {@link DataSchema#getStoredColumnDataTypes()} format. However, the transferable block ser/de stores data in the
-   * {@link DataSchema#getColumnDataTypes()} format.
-   *
-   * @param row un-canonicalize row.
-   * @param dataSchema data schema desired for the row.
-   * @return canonicalize row.
-   */
-  private static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema) {
-    Object[] resultRow = new Object[row.length];
-    for (int colId = 0; colId < row.length; colId++) {
-      Object value = row[colId];
-      if (value != null) {
-        if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.OBJECT) {
-          resultRow[colId] = value;
-        } else {
-          resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
-        }
-      }
-    }
-    return resultRow;
-  }
-
-  private static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema, int[] columnIndices) {
-    Object[] resultRow = new Object[columnIndices.length];
-    for (int colId = 0; colId < columnIndices.length; colId++) {
-      Object value = row[columnIndices[colId]];
-      if (value != null) {
-        resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
-      }
-    }
-    return resultRow;
-  }
-
   private static boolean isDataSchemaColumnTypesCompatible(DataSchema.ColumnDataType[] desiredTypes,
       DataSchema.ColumnDataType[] givenTypes) {
     if (desiredTypes.length != givenTypes.length) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
index 8ea4ab0a91..19e4f66cc6 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageAggregationExecutor.java
@@ -18,19 +18,17 @@
  */
 package org.apache.pinot.query.runtime.operator;
 
-import com.google.common.base.Preconditions;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
 import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 
 
 /**
@@ -41,23 +39,23 @@ public class MultistageAggregationExecutor {
   // The identifier operands for the aggregation function only store the column name. This map contains mapping
   // from column name to their index.
   private final Map<String, Integer> _colNameToIndexMap;
+  private final DataSchema _resultSchema;
 
   private final AggregationFunction[] _aggFunctions;
 
   // Result holders for each mode.
   private final AggregationResultHolder[] _aggregateResultHolder;
   private final Object[] _mergeResultHolder;
-  private final Object[] _finalResultHolder;
 
-  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions, AggType aggType,
-      Map<String, Integer> colNameToIndexMap) {
+  public MultistageAggregationExecutor(AggregationFunction[] aggFunctions,
+      AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
     _aggFunctions = aggFunctions;
     _aggType = aggType;
     _colNameToIndexMap = colNameToIndexMap;
+    _resultSchema = resultSchema;
 
     _aggregateResultHolder = new AggregationResultHolder[aggFunctions.length];
     _mergeResultHolder = new Object[aggFunctions.length];
-    _finalResultHolder = new Object[aggFunctions.length];
 
     for (int i = 0; i < _aggFunctions.length; i++) {
       _aggregateResultHolder[i] = _aggFunctions[i].createAggregationResultHolder();
@@ -70,10 +68,8 @@ public class MultistageAggregationExecutor {
   public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
     if (!_aggType.isInputIntermediateFormat()) {
       processAggregate(block, inputDataSchema);
-    } else if (_aggType.isOutputIntermediateFormat()) {
-      processMerge(block);
     } else {
-      collectResult(block);
+      processMerge(block);
     }
   }
 
@@ -93,35 +89,37 @@ public class MultistageAggregationExecutor {
    * Fetches the result.
    */
   public List<Object[]> getResult() {
-    int numFunctions = _aggFunctions.length;
-    Object[] row = new Object[numFunctions];
-    for (int i = 0; i < numFunctions; i++) {
-      AggregationFunction func = _aggFunctions[i];
-      if (!_aggType.isInputIntermediateFormat()) {
-        Object intermediateResult = func.extractAggregationResult(_aggregateResultHolder[i]);
-        if (_aggType.isOutputIntermediateFormat()) {
-          row[i] = intermediateResult;
-        } else {
-          Object finalResult = func.extractFinalResult(intermediateResult);
-          row[i] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
-        }
-      } else {
-        if (_aggType.isOutputIntermediateFormat()) {
-          row[i] = _mergeResultHolder[i];
-        } else {
-          Object finalResult = func.extractFinalResult(_finalResultHolder[i]);
-          row[i] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
-        }
+    Object[] row = new Object[_aggFunctions.length];
+    for (int i = 0; i < _aggFunctions.length; i++) {
+      AggregationFunction aggFunction = _aggFunctions[i];
+      Object value;
+      switch (_aggType) {
+        case LEAF:
+          value = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+          break;
+        case INTERMEDIATE:
+          value = _mergeResultHolder[i];
+          break;
+        case FINAL:
+          value = aggFunction.extractFinalResult(_mergeResultHolder[i]);
+          break;
+        case DIRECT:
+          Object intermediate = aggFunction.extractAggregationResult(_aggregateResultHolder[i]);
+          value = aggFunction.extractFinalResult(intermediate);
+          break;
+        default:
+          throw new UnsupportedOperationException("Unsupported aggType: " + _aggType);
       }
+      row[i] = value;
     }
-    return Collections.singletonList(row);
+    return Collections.singletonList(TypeUtils.canonicalizeRow(row, _resultSchema));
   }
 
   private void processAggregate(TransferableBlock block, DataSchema inputDataSchema) {
     for (int i = 0; i < _aggFunctions.length; i++) {
       AggregationFunction aggregationFunction = _aggFunctions[i];
       Map<ExpressionContext, BlockValSet> blockValSetMap =
-          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+          AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
       aggregationFunction.aggregate(block.getNumRows(), _aggregateResultHolder[i], blockValSetMap);
     }
   }
@@ -131,7 +129,8 @@ public class MultistageAggregationExecutor {
 
     for (int i = 0; i < _aggFunctions.length; i++) {
       for (Object[] row : container) {
-        Object intermediateResultToMerge = extractValueFromRow(_aggFunctions[i], row);
+        Object intermediateResultToMerge =
+            AggregateOperator.extractValueFromRow(_aggFunctions[i], row, _colNameToIndexMap);
 
         // Not all V1 aggregation functions have null-handling logic. Handle null values before calling merge.
         if (intermediateResultToMerge == null) {
@@ -147,50 +146,4 @@ public class MultistageAggregationExecutor {
       }
     }
   }
-
-  private void collectResult(TransferableBlock block) {
-    List<Object[]> container = block.getContainer();
-    assert container.size() == 1;
-    Object[] row = container.get(0);
-    for (int i = 0; i < _aggFunctions.length; i++) {
-      _finalResultHolder[i] = extractValueFromRow(_aggFunctions[i], row);
-    }
-  }
-
-  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
-      TransferableBlock block, DataSchema inputDataSchema) {
-    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
-    int numExpressions = expressions.size();
-    if (numExpressions == 0) {
-      return Collections.emptyMap();
-    }
-
-    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
-    ExpressionContext expression = expressions.get(0);
-    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
-    int index = _colNameToIndexMap.get(expression.getIdentifier());
-
-    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
-    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
-    // TODO: If the previous block is not mailbox received, this method is not efficient.  Then getDataBlock() will
-    //  convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
-    //  value primitive type.
-    return Collections.singletonMap(expression,
-        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
-  }
-
-  private Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row) {
-    // TODO: Add support to handle aggregation functions where:
-    //       1. The identifier need not be the first argument
-    //       2. There are more than one identifiers.
-    List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
-    Preconditions.checkState(expressions.size() == 1, "Only support single expression, got: %s", expressions.size());
-    ExpressionContext expr = expressions.get(0);
-    ExpressionContext.Type exprType = expr.getType();
-    if (exprType == ExpressionContext.Type.IDENTIFIER) {
-      return row[_colNameToIndexMap.get(expr.getIdentifier())];
-    }
-    Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
-    return expr.getLiteral().getValue();
-  }
 }
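
As a summary of the switch in getResult(), the four modes roughly map to (per aggregation function i):

    LEAF:         aggregate raw input  -> extractAggregationResult(holder[i])             (intermediate out)
    INTERMEDIATE: merge intermediates  -> mergeResultHolder[i]                            (intermediate out)
    FINAL:        merge intermediates  -> extractFinalResult(mergeResultHolder[i])        (final out)
    DIRECT:       aggregate + finalize -> extractFinalResult(extractAggregationResult())  (final out)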
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
index 03d7638aa4..5eacba025b 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/MultistageGroupByExecutor.java
@@ -18,23 +18,21 @@
  */
 package org.apache.pinot.query.runtime.operator;
 
-import com.google.common.base.Preconditions;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import org.apache.pinot.common.datablock.DataBlock;
 import org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.common.BlockValSet;
-import org.apache.pinot.core.common.IntermediateStageBlockValSet;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.core.plan.maker.InstancePlanMakerImplV2;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
 import org.apache.pinot.query.planner.plannode.AggregateNode.AggType;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
+
 
 
 /**
@@ -45,6 +43,7 @@ public class MultistageGroupByExecutor {
   // The identifier operands for the aggregation function only store the column name. This map contains mapping
   // between column name to their index which is used in v2 engine.
   private final Map<String, Integer> _colNameToIndexMap;
+  private final DataSchema _resultSchema;
 
   private final List<ExpressionContext> _groupSet;
   private final AggregationFunction[] _aggFunctions;
@@ -52,22 +51,21 @@ public class MultistageGroupByExecutor {
   // Group By Result holders for each mode
   private final GroupByResultHolder[] _aggregateResultHolders;
   private final Map<Integer, Object[]> _mergeResultHolder;
-  private final List<Object[]> _finalResultHolder;
 
   // Mapping from the row-key to a zero based integer index. This is used when we invoke the v1 aggregation functions
   // because they use the zero based integer indexes to store results.
   private final Map<Key, Integer> _groupKeyToIdMap;
 
   public MultistageGroupByExecutor(List<ExpressionContext> groupByExpr, AggregationFunction[] aggFunctions,
-      AggType aggType, Map<String, Integer> colNameToIndexMap) {
+      AggType aggType, Map<String, Integer> colNameToIndexMap, DataSchema resultSchema) {
     _aggType = aggType;
     _colNameToIndexMap = colNameToIndexMap;
     _groupSet = groupByExpr;
     _aggFunctions = aggFunctions;
+    _resultSchema = resultSchema;
 
     _aggregateResultHolders = new GroupByResultHolder[_aggFunctions.length];
     _mergeResultHolder = new HashMap<>();
-    _finalResultHolder = new ArrayList<>();
 
     _groupKeyToIdMap = new HashMap<>();
 
@@ -84,10 +82,8 @@ public class MultistageGroupByExecutor {
   public void processBlock(TransferableBlock block, DataSchema inputDataSchema) {
     if (!_aggType.isInputIntermediateFormat()) {
       processAggregate(block, inputDataSchema);
-    } else if (_aggType.isOutputIntermediateFormat()) {
-      processMerge(block);
     } else {
-      collectResult(block);
+      processMerge(block);
     }
   }
 
@@ -95,10 +91,6 @@ public class MultistageGroupByExecutor {
    * Fetches the result.
    */
   public List<Object[]> getResult() {
-    if (_aggType == AggType.FINAL) {
-      return extractFinalGroupByResult();
-    }
-
     List<Object[]> rows = new ArrayList<>(_groupKeyToIdMap.size());
     int numKeys = _groupSet.size();
     int numFunctions = _aggFunctions.length;
@@ -111,39 +103,27 @@ public class MultistageGroupByExecutor {
       for (int i = 0; i < numFunctions; i++) {
         AggregationFunction func = _aggFunctions[i];
         int index = numKeys + i;
-        if (!_aggType.isInputIntermediateFormat()) {
-          Object intermediateResult = func.extractGroupByResult(_aggregateResultHolders[i], groupId);
-          if (_aggType.isOutputIntermediateFormat()) {
-            row[index] = intermediateResult;
-          } else {
-            Object finalResult = func.extractFinalResult(intermediateResult);
-            row[index] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
-          }
-        } else {
-          assert _aggType == AggType.INTERMEDIATE;
-          row[index] = _mergeResultHolder.get(groupId)[i];
+        Object value;
+        switch (_aggType) {
+          case LEAF:
+            value = func.extractGroupByResult(_aggregateResultHolders[i], groupId);
+            break;
+          case INTERMEDIATE:
+            value = _mergeResultHolder.get(groupId)[i];
+            break;
+          case FINAL:
+            value = func.extractFinalResult(_mergeResultHolder.get(groupId)[i]);
+            break;
+          case DIRECT:
+            Object intermediate = _aggFunctions[i].extractGroupByResult(_aggregateResultHolders[i], groupId);
+            value = func.extractFinalResult(intermediate);
+            break;
+          default:
+            throw new UnsupportedOperationException("Unsupported aggType: " + _aggType);
         }
+        row[index] = value;
       }
-      rows.add(row);
-    }
-    return rows;
-  }
-
-  private List<Object[]> extractFinalGroupByResult() {
-    List<Object[]> rows = new ArrayList<>(_finalResultHolder.size());
-    int numKeys = _groupSet.size();
-    int numFunctions = _aggFunctions.length;
-    int numColumns = numKeys + numFunctions;
-    for (Object[] resultRow : _finalResultHolder) {
-      Object[] row = new Object[numColumns];
-      System.arraycopy(resultRow, 0, row, 0, numKeys);
-      for (int i = 0; i < numFunctions; i++) {
-        AggregationFunction func = _aggFunctions[i];
-        int index = numKeys + i;
-        Object finalResult = func.extractFinalResult(resultRow[index]);
-        row[index] = finalResult == null ? null : func.getFinalResultColumnType().convert(finalResult);
-      }
-      rows.add(row);
+      rows.add(TypeUtils.canonicalizeRow(row, _resultSchema));
     }
     return rows;
   }
@@ -154,7 +134,7 @@ public class MultistageGroupByExecutor {
     for (int i = 0; i < _aggFunctions.length; i++) {
       AggregationFunction aggregationFunction = _aggFunctions[i];
       Map<ExpressionContext, BlockValSet> blockValSetMap =
-          getBlockValSetMap(aggregationFunction, block, inputDataSchema);
+          AggregateOperator.getBlockValSetMap(aggregationFunction, block, inputDataSchema, _colNameToIndexMap);
       GroupByResultHolder groupByResultHolder = _aggregateResultHolders[i];
       groupByResultHolder.ensureCapacity(_groupKeyToIdMap.size());
       aggregationFunction.aggregateGroupBySV(block.getNumRows(), intKeys, groupByResultHolder, blockValSetMap);
@@ -172,7 +152,8 @@ public class MultistageGroupByExecutor {
         if (!_mergeResultHolder.containsKey(rowKey)) {
           _mergeResultHolder.put(rowKey, new Object[_aggFunctions.length]);
         }
-        Object intermediateResultToMerge = extractValueFromRow(_aggFunctions[i], row);
+        Object intermediateResultToMerge =
+            AggregateOperator.extractValueFromRow(_aggFunctions[i], row, _colNameToIndexMap);
 
         // Not all V1 aggregation functions have null-handling. So handle null values and call merge only if necessary.
         if (intermediateResultToMerge == null) {
@@ -189,22 +170,6 @@ public class MultistageGroupByExecutor {
     }
   }
 
-  private void collectResult(TransferableBlock block) {
-    List<Object[]> container = block.getContainer();
-    for (Object[] row : container) {
-      assert row.length == _groupSet.size() + _aggFunctions.length;
-      Object[] resultRow = new Object[row.length];
-      System.arraycopy(row, 0, resultRow, 0, _groupSet.size());
-
-      for (int i = 0; i < _aggFunctions.length; i++) {
-        int index = _groupSet.size() + i;
-        resultRow[index] = extractValueFromRow(_aggFunctions[i], row);
-      }
-
-      _finalResultHolder.add(resultRow);
-    }
-  }
-
   /**
    * Creates the group by key for each row. Converts the key into a 0-index based int value that can be used by
    * GroupByAggregationResultHolders used in v1 aggregations.
@@ -226,41 +191,4 @@ public class MultistageGroupByExecutor {
     }
     return rowIntKeys;
   }
-
-  private Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggFunction,
-      TransferableBlock block, DataSchema inputDataSchema) {
-    List<ExpressionContext> expressions = aggFunction.getInputExpressions();
-    int numExpressions = expressions.size();
-    if (numExpressions == 0) {
-      return Collections.emptyMap();
-    }
-
-    Preconditions.checkState(numExpressions == 1, "Cannot handle more than one identifier in aggregation function.");
-    ExpressionContext expression = expressions.get(0);
-    Preconditions.checkState(expression.getType().equals(ExpressionContext.Type.IDENTIFIER));
-    int index = _colNameToIndexMap.get(expression.getIdentifier());
-
-    DataSchema.ColumnDataType dataType = inputDataSchema.getColumnDataType(index);
-    Preconditions.checkState(block.getType().equals(DataBlock.Type.ROW), "Datablock type is not ROW");
-    // TODO: If the previous block is not mailbox received, this method is not efficient.  Then getDataBlock() will
-    //  convert the unserialized format to serialized format of BaseDataBlock. Then it will convert it back to column
-    //  value primitive type.
-    return Collections.singletonMap(expression,
-        new IntermediateStageBlockValSet(dataType, block.getDataBlock(), index));
-  }
-
-  private Object extractValueFromRow(AggregationFunction aggregationFunction, Object[] row) {
-    // TODO: Add support to handle aggregation functions where:
-    //       1. The identifier need not be the first argument
-    //       2. There are more than one identifiers.
-    List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
-    Preconditions.checkState(expressions.size() == 1, "Only support single expression, got: %s", expressions.size());
-    ExpressionContext expr = expressions.get(0);
-    ExpressionContext.Type exprType = expr.getType();
-    if (exprType == ExpressionContext.Type.IDENTIFIER) {
-      return row[_colNameToIndexMap.get(expr.getIdentifier())];
-    }
-    Preconditions.checkState(exprType == ExpressionContext.Type.LITERAL, "Unsupported expression type: %s", exprType);
-    return expr.getLiteral().getValue();
-  }
 }
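
A hedged sketch of the group-key bookkeeping this executor relies on (extractKey below is a
hypothetical helper; the real code builds a Key from the group-set columns of each row). Each
distinct key gets a dense 0-based id so the V1 GroupByResultHolders can index results by int:

    Map<Key, Integer> groupKeyToId = new HashMap<>();
    int[] intKeys = new int[rows.size()];
    for (int r = 0; r < rows.size(); r++) {
      Key key = extractKey(rows.get(r));  // hypothetical helper, for illustration only
      // size() is evaluated before insertion, yielding a dense 0-based id per new key
      intKeys[r] = groupKeyToId.computeIfAbsent(key, k -> groupKeyToId.size());
    }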
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
index e53bcfdac2..a913d95295 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/TransformOperator.java
@@ -29,7 +29,7 @@ import org.apache.pinot.query.planner.logical.RexExpression;
 import org.apache.pinot.query.runtime.blocks.TransferableBlock;
 import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -113,8 +113,7 @@ public class TransformOperator extends MultiStageOperator {
     for (Object[] row : container) {
       Object[] resultRow = new Object[_resultColumnSize];
       for (int i = 0; i < _resultColumnSize; i++) {
-        resultRow[i] =
-            FunctionInvokeUtils.convert(_transformOperandsList.get(i).apply(row), _resultSchema.getColumnDataType(i));
+        resultRow[i] = TypeUtils.convert(_transformOperandsList.get(i).apply(row), _resultSchema.getColumnDataType(i));
       }
       resultRows.add(resultRow);
     }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
index ab69ce67b2..dbe3c179ca 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
@@ -24,7 +24,7 @@ import java.util.List;
 import java.util.function.IntPredicate;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
+import org.apache.pinot.query.runtime.operator.utils.TypeUtils;
 import org.apache.pinot.spi.utils.BooleanUtils;
 
 
@@ -160,8 +160,8 @@ public abstract class FilterOperand extends TransformOperand {
         return false;
       }
       if (_requireCasting) {
-        v1 = (Comparable) FunctionInvokeUtils.convert(v1, _commonCastType);
-        v2 = (Comparable) FunctionInvokeUtils.convert(v2, _commonCastType);
+        v1 = (Comparable) TypeUtils.convert(v1, _commonCastType);
+        v2 = (Comparable) TypeUtils.convert(v2, _commonCastType);
         assert v1 != null && v2 != null;
       }
       return _comparisonResultPredicate.test(v1.compareTo(v2));
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
deleted file mode 100644
index 851748d2b2..0000000000
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.query.runtime.operator.utils;
-
-import javax.annotation.Nullable;
-import org.apache.pinot.common.utils.DataSchema;
-
-
-public class FunctionInvokeUtils {
-  private FunctionInvokeUtils() {
-  }
-
-  /**
-   * Convert result to the appropriate column data type according to the desired {@link DataSchema.ColumnDataType}
-   * of the {@link org.apache.pinot.core.common.Operator}.
-   *
-   * @param inputObj input entry
-   * @param columnDataType desired column data type
-   * @return converted entry
-   */
-  @Nullable
-  public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
-    if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
-      return inputObj == null ? null : columnDataType.convert(inputObj);
-    } else {
-      return inputObj;
-    }
-  }
-}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java
new file mode 100644
index 0000000000..d60ad0bdb7
--- /dev/null
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator.utils;
+
+import javax.annotation.Nullable;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.spi.utils.BooleanUtils;
+
+
+public class TypeUtils {
+  private TypeUtils() {
+  }
+
+  /**
+   * Convert result to the appropriate column data type according to the desired {@link DataSchema.ColumnDataType}
+   * of the {@link org.apache.pinot.core.common.Operator}.
+   *
+   * @param inputObj input entry
+   * @param columnDataType desired column data type
+   * @return converted entry
+   */
+  @Nullable
+  public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
+    if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
+      return inputObj == null ? null : columnDataType.convert(inputObj);
+    } else {
+      return inputObj;
+    }
+  }
+
+  /**
+   * Canonicalizes a row generated by the V1 engine, which stores values in the
+   * {@link DataSchema#getStoredColumnDataTypes()} format, whereas the transferable block ser/de stores data in the
+   * {@link DataSchema#getColumnDataTypes()} format.
+   *
+   * @param row un-canonicalized row.
+   * @param dataSchema data schema desired for the row.
+   * @return canonicalized row.
+   */
+  public static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema) {
+    Object[] resultRow = new Object[row.length];
+    for (int colId = 0; colId < row.length; colId++) {
+      Object value = row[colId];
+      if (value != null) {
+        if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.OBJECT) {
+          resultRow[colId] = value;
+        } else if (dataSchema.getColumnDataType(colId) == DataSchema.ColumnDataType.BOOLEAN) {
+          resultRow[colId] = BooleanUtils.toBoolean(value);
+        } else {
+          resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
+        }
+      }
+    }
+    return resultRow;
+  }
+
+  /**
+   * Canonicalizes a row whose column indices do not match the Calcite output order.
+   *
+   * @see TypeUtils#canonicalizeRow(Object[], DataSchema)
+   */
+  public static Object[] canonicalizeRow(Object[] row, DataSchema dataSchema, int[] columnIndices) {
+    Object[] resultRow = new Object[columnIndices.length];
+    for (int colId = 0; colId < columnIndices.length; colId++) {
+      Object value = row[columnIndices[colId]];
+      if (value != null) {
+        resultRow[colId] = dataSchema.getColumnDataType(colId).convert(value);
+      }
+    }
+    return resultRow;
+  }
+}
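
A minimal usage sketch for canonicalizeRow, assuming a (LONG, BOOLEAN) schema: the V1 stored form
may carry an Integer for the LONG column and an int 1/0 for the BOOLEAN column, and canonicalization
converts each cell to its logical type (nulls pass through):

    DataSchema schema = new DataSchema(new String[]{"cnt", "flag"},
        new DataSchema.ColumnDataType[]{
            DataSchema.ColumnDataType.LONG, DataSchema.ColumnDataType.BOOLEAN});
    Object[] stored = new Object[]{1, 1};  // V1 stored representation
    Object[] canonical = TypeUtils.canonicalizeRow(stored, schema);
    // expected under the logic above: {1L, true}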
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index 6ecb29efb4..8048bad5f4 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -41,6 +41,7 @@ import org.testng.annotations.Test;
 
 import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.DOUBLE;
 import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.INT;
+import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.LONG;
 import static org.apache.pinot.common.utils.DataSchema.ColumnDataType.STRING;
 
 public class AggregateOperatorTest {
@@ -74,8 +75,8 @@ public class AggregateOperatorTest {
     Mockito.when(_input.nextBlock())
         .thenReturn(TransferableBlockUtils.getErrorTransferableBlock(new Exception("foo!")));
 
-    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
-    DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
             AggType.INTERMEDIATE);
@@ -97,10 +98,10 @@ public class AggregateOperatorTest {
     Mockito.when(_input.nextBlock()).thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
     DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
-    DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.LEAF);
 
     // When:
     TransferableBlock block = operator.nextBlock();
@@ -121,10 +122,10 @@ public class AggregateOperatorTest {
         .thenReturn(TransferableBlockUtils.getNoOpTransferableBlock())
         .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.LEAF);
 
     // When:
     TransferableBlock block1 = operator.nextBlock(); // build when reading NoOp block
@@ -142,11 +143,11 @@ public class AggregateOperatorTest {
     List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.InputRef(1)));
     List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
 
-    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
-    Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1}))
+    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+    Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 1.0}))
         .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
             AggType.INTERMEDIATE);
@@ -158,25 +159,25 @@ public class AggregateOperatorTest {
     // Then:
     Mockito.verify(_input, Mockito.times(2)).nextBlock();
     Assert.assertTrue(block1.getNumRows() > 0, "First block is the result");
-    Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1},
-        "Expected two columns (group by key, agg value)");
+    Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1.0},
+        "Expected two columns (group by key, agg value), agg value is intermediate type");
     Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)");
   }
 
   @Test
   public void shouldAggregateSingleInputBlockWithLiteralInput() {
     // Given:
-    List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.Literal(FieldSpec.DataType.INT, 1)));
+    List<RexExpression> calls = ImmutableList.of(getSum(new RexExpression.Literal(FieldSpec.DataType.DOUBLE, 1.0)));
     List<RexExpression> group = ImmutableList.of(new RexExpression.InputRef(0));
 
-    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
-    Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3}))
+    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, DOUBLE});
+    Mockito.when(_input.nextBlock()).thenReturn(OperatorTestUtil.block(inSchema, new Object[]{2, 3.0}))
         .thenReturn(TransferableBlockUtils.getEndOfStreamTransferableBlock());
 
-    DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{INT, LONG});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, inSchema, calls, group,
-            AggType.INTERMEDIATE);
+            AggType.FINAL);
 
     // When:
     TransferableBlock block1 = operator.nextBlock();
@@ -186,7 +187,7 @@ public class AggregateOperatorTest {
     Mockito.verify(_input, Mockito.times(2)).nextBlock();
     Assert.assertTrue(block1.getNumRows() > 0, "First block is the result");
     // second value is 1 (the literal) instead of 3 (the col val)
-    Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1},
+    Assert.assertEquals(block1.getContainer().get(0), new Object[]{2, 1L},
         "Expected two columns (group by key, agg value)");
     Assert.assertTrue(block2.isEndOfStreamBlock(), "Second block is EOS (done processing)");
   }
@@ -196,9 +197,10 @@ public class AggregateOperatorTest {
     MultiStageOperator upstreamOperator = OperatorTestUtil.getOperator(OperatorTestUtil.OP_1);
     // Create an aggregation call with sum for first column and group by second column.
     RexExpression.FunctionCall agg = getSum(new RexExpression.InputRef(0));
-    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
+    DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{STRING, INT});
+    DataSchema outSchema = new DataSchema(new String[]{"group", "sum"}, new ColumnDataType[]{STRING, DOUBLE});
     AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
-        OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), inSchema, Collections.singletonList(agg),
+        outSchema, inSchema, Collections.singletonList(agg),
         Collections.singletonList(new RexExpression.InputRef(1)), AggType.LEAF);
     TransferableBlock result = sum0GroupBy1.getNextBlock();
     while (result.isNoOpBlock()) {
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index f377ca188e..70f5fcc6ef 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -24,14 +24,11 @@ import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
-import org.apache.calcite.config.CalciteSystemProperty;
 import org.apache.calcite.sql.SqlFunctionCategory;
-import org.apache.calcite.sql.SqlIdentifier;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
 import org.apache.calcite.sql.type.SqlOperandTypeChecker;
-import org.apache.calcite.sql.type.SqlOperandTypeInference;
 import org.apache.calcite.sql.type.SqlReturnTypeInference;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.commons.lang.StringUtils;
@@ -50,32 +47,37 @@ import org.apache.pinot.spi.utils.CommonConstants;
  */
 public enum AggregationFunctionType {
   // Aggregation functions for single-valued columns
-  COUNT("count", Collections.emptyList(), true, null, SqlKind.COUNT, ReturnTypes.BIGINT, null,
-      CalciteSystemProperty.STRICT.value() ? OperandTypes.ANY : OperandTypes.ONE_OR_MORE, SqlFunctionCategory.NUMERIC,
-      null, ReturnTypes.BIGINT, null, null, null),
-  MIN("min", Collections.emptyList(), false, null, SqlKind.MIN, ReturnTypes.DOUBLE, null,
-      OperandTypes.COMPARABLE_ORDERED, SqlFunctionCategory.SYSTEM, null, ReturnTypes.DOUBLE, null, null, null),
-  MAX("max", Collections.emptyList(), false, null, SqlKind.MAX, ReturnTypes.DOUBLE, null,
-      OperandTypes.COMPARABLE_ORDERED, SqlFunctionCategory.SYSTEM, null, ReturnTypes.DOUBLE, null, null, null),
-  // In multistage SUM is reduced via the PinotAvgSumAggregateReduceFunctionsRule, need not set up anything for reduce
-  SUM("sum", Collections.emptyList(), false, null, SqlKind.SUM, ReturnTypes.DOUBLE, null, OperandTypes.NUMERIC,
-      SqlFunctionCategory.NUMERIC, null, ReturnTypes.DOUBLE, null, null, null),
-  SUM0("$sum0", Collections.emptyList(), false, null, SqlKind.SUM0, ReturnTypes.DOUBLE, null, OperandTypes.NUMERIC,
-      SqlFunctionCategory.NUMERIC, null, ReturnTypes.DOUBLE, null, null, null),
+  COUNT("count", null, SqlKind.COUNT, SqlFunctionCategory.NUMERIC, OperandTypes.ONE_OR_MORE,
+      ReturnTypes.explicit(SqlTypeName.BIGINT), ReturnTypes.explicit(SqlTypeName.BIGINT)),
+  // TODO: MIN/MAX only support NUMERIC in Pinot, whereas Calcite supports COMPARABLE_ORDERED
+  MIN("min", null, SqlKind.MIN, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+      ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+  MAX("max", null, SqlKind.MAX, SqlFunctionCategory.SYSTEM, OperandTypes.NUMERIC, ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+      ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+  SUM("sum", null, SqlKind.SUM, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC, ReturnTypes.AGG_SUM,
+      ReturnTypes.explicit(SqlTypeName.DOUBLE)),
+  SUM0("$sum0", null, SqlKind.SUM0, SqlFunctionCategory.NUMERIC, OperandTypes.NUMERIC,
+      ReturnTypes.AGG_SUM_EMPTY_IS_ZERO, ReturnTypes.explicit(SqlTypeName.DOUBLE)),
   SUMPRECISION("sumPrecision"),
-  AVG("avg", Collections.emptyList(), true, null, SqlKind.AVG, ReturnTypes.AVG_AGG_FUNCTION, null,
-      OperandTypes.NUMERIC, SqlFunctionCategory.NUMERIC, null, null, "AVG_REDUCE", ReturnTypes.AVG_AGG_FUNCTION,
-      OperandTypes.NUMERIC_NUMERIC),
+  AVG("avg"),
   MODE("mode"),
+
   FIRSTWITHTIME("firstWithTime"),
   LASTWITHTIME("lastWithTime"),
   MINMAXRANGE("minMaxRange"),
-  DISTINCTCOUNT("distinctCount", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BIGINT,
-      null, OperandTypes.ANY, SqlFunctionCategory.USER_DEFINED_FUNCTION, null,
-      ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
+  /**
+   * For all distinct count family functions:
+   * (1) distinct_count only supports a single argument;
+   * (2) count(distinct ...) supports multiple arguments and will be converted into DISTINCT + COUNT.
+   */
+  DISTINCTCOUNT("distinctCount", null, SqlKind.OTHER_FUNCTION,
+      SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT,
+      ReturnTypes.explicit(SqlTypeName.OTHER)),
   DISTINCTCOUNTBITMAP("distinctCountBitmap"),
   SEGMENTPARTITIONEDDISTINCTCOUNT("segmentPartitionedDistinctCount"),
-  DISTINCTCOUNTHLL("distinctCountHLL"),
+  DISTINCTCOUNTHLL("distinctCountHLL", Collections.emptyList(), SqlKind.OTHER_FUNCTION,
+      SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.ANY, ReturnTypes.BIGINT,
+      ReturnTypes.explicit(SqlTypeName.OTHER)),
   DISTINCTCOUNTRAWHLL("distinctCountRawHLL"),
   DISTINCTCOUNTSMARTHLL("distinctCountSmartHLL"),
   FASTHLL("fastHLL"),
@@ -83,6 +85,7 @@ public enum AggregationFunctionType {
   DISTINCTCOUNTRAWTHETASKETCH("distinctCountRawThetaSketch"),
   DISTINCTSUM("distinctSum"),
   DISTINCTAVG("distinctAvg"),
+
   PERCENTILE("percentile"),
   PERCENTILEEST("percentileEst"),
   PERCENTILERAWEST("percentileRawEst"),
@@ -91,20 +94,21 @@ public enum AggregationFunctionType {
   PERCENTILESMARTTDIGEST("percentileSmartTDigest"),
   PERCENTILEKLL("percentileKLL"),
   PERCENTILERAWKLL("percentileRawKLL"),
+
   IDSET("idSet"),
+
   HISTOGRAM("histogram"),
+
   COVARPOP("covarPop"),
   COVARSAMP("covarSamp"),
   VARPOP("varPop"),
   VARSAMP("varSamp"),
   STDDEVPOP("stdDevPop"),
   STDDEVSAMP("stdDevSamp"),
-  SKEWNESS("skewness", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.DOUBLE, null,
-      OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION, "fourthMoment",
-      ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
-  KURTOSIS("kurtosis", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.DOUBLE, null,
-      OperandTypes.NUMERIC, SqlFunctionCategory.USER_DEFINED_FUNCTION, "fourthMoment",
-      ReturnTypes.explicit(SqlTypeName.OTHER), null, null, null),
+  SKEWNESS("skewness", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+      OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
+  KURTOSIS("kurtosis", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+      OperandTypes.NUMERIC, ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
   FOURTHMOMENT("fourthMoment"),
 
   // DataSketches Tuple Sketch support
@@ -141,10 +145,10 @@ public enum AggregationFunctionType {
   PERCENTILERAWKLLMV("percentileRawKLLMV"),
 
   // boolean aggregate functions
-  BOOLAND("boolAnd", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null,
-      OperandTypes.BOOLEAN, SqlFunctionCategory.USER_DEFINED_FUNCTION, null, ReturnTypes.INTEGER, null, null, null),
-  BOOLOR("boolOr", Collections.emptyList(), false, null, SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null,
-      OperandTypes.BOOLEAN, SqlFunctionCategory.USER_DEFINED_FUNCTION, null, ReturnTypes.INTEGER, null, null, null),
+  BOOLAND("boolAnd", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+      OperandTypes.BOOLEAN, ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)),
+  BOOLOR("boolOr", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
+      OperandTypes.BOOLEAN, ReturnTypes.BOOLEAN, ReturnTypes.explicit(SqlTypeName.INTEGER)),
 
   // argMin and argMax
   ARGMIN("argMin"),
@@ -160,61 +164,64 @@ public enum AggregationFunctionType {
   private static final Set<String> NAMES = Arrays.stream(values()).flatMap(func -> Stream.of(func.name(),
       func.getName(), func.getName().toLowerCase())).collect(Collectors.toSet());
 
+  // --------------------------------------------------------------------------
+  // Function signature used by Calcite.
+  // --------------------------------------------------------------------------
   private final String _name;
   private final List<String> _alternativeNames;
-  private final boolean _isNativeCalciteAggregationFunctionType;
 
   // Fields for registering the aggregation function with Calcite in multistage. These are typically used for the
   // user facing aggregation functions and the return and operand types should reflect that which is user facing.
-  private final SqlIdentifier _sqlIdentifier;
   private final SqlKind _sqlKind;
-  private final SqlReturnTypeInference _sqlReturnTypeInference;
-  private final SqlOperandTypeInference _sqlOperandTypeInference;
-  private final SqlOperandTypeChecker _sqlOperandTypeChecker;
   private final SqlFunctionCategory _sqlFunctionCategory;
 
-  // Fields for the intermediate stage functions used in multistage. These intermediate stage aggregation functions
-  // are typically internal to Pinot and are used to handle the complex types used in the intermediate and final
-  // aggregation stages. The intermediate function name may be the same as the user facing function name.
-  private final String _intermediateFunctionName;
-  private final SqlReturnTypeInference _sqlIntermediateReturnTypeInference;
-
-  // Fields for the Calcite aggregation reduce step in multistage. The aggregation reduce is only used for special
-  // functions like AVG which is split into SUM / COUNT. This is needed for proper null handling if COUNT is 0.
-  private final String _reduceFunctionName;
-  private final SqlReturnTypeInference _sqlReduceReturnTypeInference;
-  private final SqlOperandTypeChecker _sqlReduceOperandTypeChecker;
+  // override options for Pinot aggregate functions that expect a different return type or operand type
+  private final SqlReturnTypeInference _returnTypeInference;
+  private final SqlOperandTypeChecker _operandTypeChecker;
+  // override options for Pinot aggregate rules to insert intermediate result types that differ from the return type.
+  private final SqlReturnTypeInference _intermediateReturnTypeInference;
 
   /**
    * Constructor to use for aggregation functions which are only supported in v1 engine today
    */
   AggregationFunctionType(String name) {
-    this(name, Collections.emptyList(), false, null, null, null, null, null, null, null, null, null, null, null);
+    this(name, null, null, null);
   }
 
   /**
    * Constructor to use for aggregation functions which are supported in both v1 and multistage engines
    */
-  AggregationFunctionType(String name, List<String> alternativeNames, boolean isNativeCalciteAggregationFunctionType,
-      SqlIdentifier sqlIdentifier, SqlKind sqlKind, SqlReturnTypeInference sqlReturnTypeInference,
-      SqlOperandTypeInference sqlOperandTypeInference, SqlOperandTypeChecker sqlOperandTypeChecker,
-      SqlFunctionCategory sqlFunctionCategory, String intermediateFunctionName,
-      SqlReturnTypeInference sqlIntermediateReturnTypeInference, String reduceFunctionName,
-      SqlReturnTypeInference sqlReduceReturnTypeInference, SqlOperandTypeChecker sqlReduceOperandTypeChecker) {
+  AggregationFunctionType(String name, List<String> alternativeNames, SqlKind sqlKind,
+      SqlFunctionCategory sqlFunctionCategory) {
+    this(name, alternativeNames, sqlKind, sqlFunctionCategory, null, null, null);
+  }
+
+  /**
+   * Constructor to use for aggregation functions which are supported in both v1 and multistage engines
+   * and require overriding Calcite behaviors.
+   */
+  AggregationFunctionType(String name, List<String> alternativeNames,
+      SqlKind sqlKind, SqlFunctionCategory sqlFunctionCategory, SqlOperandTypeChecker operandTypeChecker,
+      SqlReturnTypeInference returnTypeInference) {
+    this(name, alternativeNames, sqlKind, sqlFunctionCategory, operandTypeChecker, returnTypeInference, null);
+  }
+
+  AggregationFunctionType(String name, List<String> alternativeNames,
+      SqlKind sqlKind, SqlFunctionCategory sqlFunctionCategory, SqlOperandTypeChecker operandTypeChecker,
+      SqlReturnTypeInference returnTypeInference, SqlReturnTypeInference intermediateReturnTypeInference) {
     _name = name;
-    _alternativeNames = alternativeNames;
-    _isNativeCalciteAggregationFunctionType = isNativeCalciteAggregationFunctionType;
-    _sqlIdentifier = sqlIdentifier;
+    if (alternativeNames == null || alternativeNames.size() == 0) {
+      _alternativeNames = Collections.singletonList(getUnderscoreSplitAggregationFunctionName(_name));
+    } else {
+      _alternativeNames = alternativeNames;
+    }
     _sqlKind = sqlKind;
-    _sqlReturnTypeInference = sqlReturnTypeInference;
-    _sqlOperandTypeInference = sqlOperandTypeInference;
-    _sqlOperandTypeChecker = sqlOperandTypeChecker;
     _sqlFunctionCategory = sqlFunctionCategory;
-    _intermediateFunctionName = intermediateFunctionName == null ? name : intermediateFunctionName;
-    _sqlIntermediateReturnTypeInference = sqlIntermediateReturnTypeInference;
-    _reduceFunctionName = reduceFunctionName;
-    _sqlReduceReturnTypeInference = sqlReduceReturnTypeInference;
-    _sqlReduceOperandTypeChecker = sqlReduceOperandTypeChecker;
+
+    _returnTypeInference = returnTypeInference;
+    _operandTypeChecker = operandTypeChecker;
+    _intermediateReturnTypeInference = intermediateReturnTypeInference == null ? _returnTypeInference
+        : intermediateReturnTypeInference;
   }
 
   public String getName() {
@@ -225,54 +232,26 @@ public enum AggregationFunctionType {
     return _alternativeNames;
   }
 
-  public boolean isNativeCalciteAggregationFunctionType() {
-    return _isNativeCalciteAggregationFunctionType;
-  }
-
-  public SqlIdentifier getSqlIdentifier() {
-    return _sqlIdentifier;
-  }
-
   public SqlKind getSqlKind() {
     return _sqlKind;
   }
 
-  public SqlReturnTypeInference getSqlReturnTypeInference() {
-    return _sqlReturnTypeInference;
+  public SqlReturnTypeInference getIntermediateReturnTypeInference() {
+    return _intermediateReturnTypeInference;
   }
 
-  public SqlOperandTypeInference getSqlOperandTypeInference() {
-    return _sqlOperandTypeInference;
+  public SqlReturnTypeInference getReturnTypeInference() {
+    return _returnTypeInference;
   }
 
-  public SqlOperandTypeChecker getSqlOperandTypeChecker() {
-    return _sqlOperandTypeChecker;
+  public SqlOperandTypeChecker getOperandTypeChecker() {
+    return _operandTypeChecker;
   }
 
   public SqlFunctionCategory getSqlFunctionCategory() {
     return _sqlFunctionCategory;
   }
 
-  public String getIntermediateFunctionName() {
-    return _intermediateFunctionName;
-  }
-
-  public SqlReturnTypeInference getSqlIntermediateReturnTypeInference() {
-    return _sqlIntermediateReturnTypeInference;
-  }
-
-  public String getReduceFunctionName() {
-    return _reduceFunctionName;
-  }
-
-  public SqlReturnTypeInference getSqlReduceReturnTypeInference() {
-    return _sqlReduceReturnTypeInference;
-  }
-
-  public SqlOperandTypeChecker getSqlReduceOperandTypeChecker() {
-    return _sqlReduceOperandTypeChecker;
-  }
-
   public static boolean isAggregationFunction(String functionName) {
     if (NAMES.contains(functionName)) {
       return true;
@@ -293,6 +272,13 @@ public enum AggregationFunctionType {
     return StringUtils.remove(StringUtils.remove(functionName, '_').toUpperCase(), "$");
   }
 
+  public static String getUnderscoreSplitAggregationFunctionName(String functionName) {
+    // Skip functions whose names contain digits for now and return them as-is
+    return functionName.matches(".*\\d.*")
+        ? functionName
+        : functionName.replaceAll("(.)(\\p{Upper}+|\\d+)", "$1_$2");
+  }
+
   /**
    * Returns the corresponding aggregation function type for the given function name.
    * <p>NOTE: Underscores in the function name are ignored.

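
For reference, a few example outputs of the underscore-splitting helper above (lookups later
normalize names by stripping underscores and upper-casing, so the exact casing of the generated
alias does not matter):

    getUnderscoreSplitAggregationFunctionName("boolAnd")           // -> "bool_And"
    getUnderscoreSplitAggregationFunctionName("distinctCountHLL")  // -> "distinct_Count_HLL"
    getUnderscoreSplitAggregationFunctionName("sum0")              // -> "sum0" (contains a digit, kept as-is)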

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org